~egtann/migrate

3ca429816fbf9aee4e796307e91496d88acd9d76 — Evan Tann 25 days ago dccc868 override
add override for mariadb
2 files changed, 145 insertions(+), 43 deletions(-)

M cmd/migrate/main.go
M migrate.go
M cmd/migrate/main.go => cmd/migrate/main.go +20 -5
@@ 27,7 27,7 @@ func run() error {
	dbUser := flag.String("u", "", "database user")
	dbHost := flag.String("h", "127.0.0.1", "database host")
	dbPort := flag.Int("p", 0, "database port")
	dbType := flag.String("t", "mysql", "type of database (mysql, postgres, sqlite)")
	dbType := flag.String("t", "mysql", "type of database (mysql, mariadb, postgres, sqlite)")
	dry := flag.Bool("d", false, "dry run")
	sslKey := flag.String("ssl-key", "", "path to client key pem")
	sslCert := flag.String("ssl-cert", "", "path to client cert pem")


@@ 93,7 93,7 @@ func run() error {
		if *dbPort == 0 {
			*dbPort = 5432
		}
	case "mysql":
	case "mysql", "mariadb":
		if *dbUser == "" {
			*dbUser = "root"
		}


@@ 123,7 123,7 @@ func run() error {
	// Prepare our database-specific configs
	var db migrate.Store
	switch *dbType {
	case "mysql":
	case "mysql", "mariadb":
		var err error
		db, err = mysql.New(*dbUser, string(password), *dbHost,
			*dbName, *dbPort, *sslKey, *sslCert, *sslCA,


@@ 146,8 146,23 @@ func run() error {
		return errors.Wrap(err, "open")
	}

	var dbt migrate.DBType
	switch *dbType {
	case "mysql":
		dbt = migrate.DBTypeMySQL
	case "mariadb":
		dbt = migrate.DBTypeMariaDB
	case "postgres":
		dbt = migrate.DBTypePostgres
	case "sqlite":
		dbt = migrate.DBTypeSQLite
	default:
		return fmt.Errorf("unknown db type: %s", *dbType)
	}

	// Prepare our database for migrations and collect the relevant files.
	m, err := migrate.New(db, migrate.StdLogger{}, *migrationDir, *skip)
	m, err := migrate.New(db, migrate.StdLogger{}, dbt, *migrationDir,
		*skip)
	if err != nil {
		return err
	}


@@ 157,7 172,7 @@ func run() error {
			return nil
		}
		for i := len(m.Migrations); i < len(m.Files); i++ {
			fmt.Println("would migrate", m.Files[i].Name())
			fmt.Println("would migrate", m.Files[i].Info.Name())
		}
		return nil
	}

M migrate.go => migrate.go +125 -38
@@ 23,36 23,51 @@ var spaces = regexp.MustCompile(`\s+`)

type Migrate struct {
	Migrations []Migration
	Files      []os.FileInfo
	Files      []*file

	db  Store
	log Logger
	dir string
	idx int
}

type file struct {
	Info     os.FileInfo
	fullpath string
}

type Migration struct {
	Filename string
	Checksum string
	Content  string
	fullpath string
}

var regexNum = regexp.MustCompile(`^\d+`)

type DBType string

const (
	DBTypeMySQL    DBType = "mysql"
	DBTypeMariaDB  DBType = "mariadb"
	DBTypePostgres DBType = "postgres"
	DBTypeSQLite   DBType = "sqlite"
)

func New(
	db Store,
	log Logger,
	dbt DBType,
	dir, skip string,
) (*Migrate, error) {
	m := &Migrate{db: db, log: log, dir: dir}
	m := &Migrate{db: db, log: log}

	// Get files in migration dir and sort them
	var err error
	m.Files, err = readdir(dir)
	m.Files, err = readDir(dir, dbt)
	if err != nil {
		return nil, errors.Wrap(err, "get migrations")
	}
	if err = sortfiles(m.Files); err != nil {
	if err = sortFiles(m.Files); err != nil {
		return nil, errors.Wrap(err, "sort")
	}



@@ 101,6 116,19 @@ func New(
		return nil, errors.Wrap(err, "get migrations")
	}

	// Fill in migration fullpath field based on the db type.
	overrides, err := getOverrideSet(dir, dbt)
	if err != nil {
		return nil, fmt.Errorf("get override set: %w", err)
	}
	for i, mg := range m.Migrations {
		override, exist := overrides[mg.Filename]
		if exist {
			m.Migrations[i].fullpath = override.fullpath
		} else {
			m.Migrations[i].fullpath = filepath.Join(dir, mg.Filename)
		}
	}
	if err = m.validHistory(); err != nil {
		return nil, err
	}


@@ 112,11 140,11 @@ func New(
func (m *Migrate) Migrate() (bool, error) {
	var migrated bool
	for i := len(m.Migrations); i < len(m.Files); i++ {
		filename := m.Files[i].Name()
		if err := m.migrateFile(filename); err != nil {
		fi := m.Files[i]
		if err := m.migrateFile(fi); err != nil {
			return false, errors.Wrap(err, "migrate file")
		}
		m.log.Println("migrated", filename)
		m.log.Println("migrated", fi.Info.Name())
		migrated = true
	}
	return migrated, nil


@@ 131,11 159,12 @@ func (m *Migrate) validHistory() error {
	}
	for i := m.idx; i < len(m.Migrations); i++ {
		mg := m.Migrations[i]
		if mg.Filename != m.Files[i].Name() {
		if mg.Filename != m.Files[i].Info.Name() {
			m.log.Printf("\n%s was added to history before %s.\n",
				m.Files[i].Name(), mg.Filename)
				m.Files[i].Info.Name(), mg.Filename)
			return errors.New("failed to migrate. migrations must be appended")
		}
		fmt.Println("MIGRATION", mg.fullpath)
		if err := m.checkHash(mg); err != nil {
			return errors.Wrap(err, "check hash")
		}


@@ 144,7 173,7 @@ func (m *Migrate) validHistory() error {
}

func (m *Migrate) checkHash(mg Migration) error {
	fi, err := os.Open(filepath.Join(m.dir, mg.Filename))
	fi, err := os.Open(mg.fullpath)
	if err != nil {
		return err
	}


@@ 161,9 190,8 @@ func (m *Migrate) checkHash(mg Migration) error {
	return nil
}

func (m *Migrate) migrateFile(filename string) error {
	pth := filepath.Join(m.dir, filename)
	byt, err := ioutil.ReadFile(pth)
func (m *Migrate) migrateFile(fi *file) error {
	byt, err := ioutil.ReadFile(fi.fullpath)
	if err != nil {
		return err
	}


@@ 180,11 208,11 @@ func (m *Migrate) migrateFile(filename string) error {

	// Ensure that commands are present
	if len(filteredCmds) == 0 {
		return fmt.Errorf("no sql statements in file: %s", filename)
		return fmt.Errorf("no sql statements in file: %s", fi.Info.Name())
	}

	// Get our checkpoints, if any
	checkpoints, err := m.db.GetMetaCheckpoints(filename)
	checkpoints, err := m.db.GetMetaCheckpoints(fi.Info.Name())
	if err != nil {
		return errors.Wrap(err, "get checkpoints")
	}


@@ 209,7 237,7 @@ func (m *Migrate) migrateFile(filename string) error {
			if checksum != checkpoints[i] {
				return fmt.Errorf(
					"checksum does not equal checkpoint. has %s (cmd %d) changed?",
					filename, i)
					fi.Info.Name(), i)
			}
			continue
		}


@@ 228,7 256,7 @@ func (m *Migrate) migrateFile(filename string) error {
		_, err := m.db.Exec(cmd)
		if err != nil {
			m.log.Println("failed on", cmd)
			return fmt.Errorf("%s: %s", filename, err)
			return fmt.Errorf("%s: %s", fi.Info.Name(), err)
		}

		// Save a checkpoint


@@ 236,7 264,7 @@ func (m *Migrate) migrateFile(filename string) error {
		if err != nil {
			return errors.Wrap(err, "compute checksum")
		}
		err = m.db.InsertMetaCheckpoint(filename, cmd, checksum, i)
		err = m.db.InsertMetaCheckpoint(fi.Info.Name(), cmd, checksum, i)
		if err != nil {
			return errors.Wrap(err, "insert checkpoint")
		}


@@ 252,7 280,8 @@ func (m *Migrate) migrateFile(filename string) error {
	if err != nil {
		return errors.Wrap(err, "compute file checksum")
	}
	if err = m.db.InsertMigration(filename, string(byt), checksum); err != nil {
	err = m.db.InsertMigration(fi.Info.Name(), string(byt), checksum)
	if err != nil {
		return errors.Wrap(err, "insert migration")
	}
	return nil


@@ 265,7 294,7 @@ func (m *Migrate) skip(toFile string) (int, error) {
	// Ensure the file exists
	index := -1
	for i, fi := range m.Files {
		if fi.Name() == toFile {
		if fi.Info.Name() == toFile {
			index = i
			break
		}


@@ 274,8 303,7 @@ func (m *Migrate) skip(toFile string) (int, error) {
		return 0, fmt.Errorf("%s does not exist", toFile)
	}
	for i := 0; i <= index; i++ {
		name := m.Files[i].Name()
		fi, err := os.Open(filepath.Join(m.dir, name))
		fi, err := os.Open(m.Files[i].fullpath)
		if err != nil {
			return -1, err
		}


@@ 284,7 312,9 @@ func (m *Migrate) skip(toFile string) (int, error) {
			fi.Close()
			return -1, err
		}
		if err = m.db.UpsertMigration(name, content, checksum); err != nil {
		name := m.Files[i].Info.Name()
		err = m.db.UpsertMigration(name, content, checksum)
		if err != nil {
			fi.Close()
			return -1, err
		}


@@ 307,14 337,28 @@ func computeChecksum(r io.Reader) (content string, checksum string, err error) {
	return string(byt), fmt.Sprintf("%x", h.Sum(nil)), nil
}

// readdir collects file infos from the migration directory.
func readdir(dir string) ([]os.FileInfo, error) {
	files := []os.FileInfo{}
// readDir collects file infos from the migration directory.
func readDir(dir string, dbt DBType) ([]*file, error) {
	files := []*file{}
	tmp, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, errors.Wrap(err, "read dir")
	}

	// Allow for DB-specific workarounds. For instance, if MySQL and
	// MariaDB are subtly incompatible (and they are, as they name
	// CONSTRAINTS in different ways), then it's possible later migrations
	// will work on one database but not another, even though they should
	// be compatible. There is no easy workaround, especially when you're
	// on an OS with access to one database but not the other.
	//
	// To ease this, we crawl through secondary directories specific to the
	// name of the DB used. If "migrate -t maria-db" then we'll look for
	// the `maria-db` folder and prefer identical migration filenames in
	// that folder over the other one.
	for _, fi := range tmp {
		fullpath := filepath.Join(dir, fi.Name())

		// Skip directories and hidden files
		if fi.IsDir() || strings.HasPrefix(fi.Name(), ".") {
			continue


@@ 323,32 367,74 @@ func readdir(dir string) ([]os.FileInfo, error) {
		if filepath.Ext(fi.Name()) != ".sql" {
			continue
		}
		files = append(files, fi)
		files = append(files, &file{Info: fi, fullpath: fullpath})
	}
	if len(files) == 0 {
		return nil, errors.New("no sql migration files found (might be the wrong -dir)")
	}

	// Prioritize our specific database over the set in the main migration
	// directory.
	overrideSet, err := getOverrideSet(dir, dbt)
	if err != nil {
		return nil, fmt.Errorf("get override set: %w", err)
	}
	for i, fi := range files {
		if override, exist := overrideSet[fi.Info.Name()]; exist {
			files[i] = override
			fmt.Println("OVERRIDING", override.Info.Name())
		}
	}
	return files, nil
}

// sortfiles by name, ensuring that something like 1.sql, 2.sql, 10.sql is
func getOverrideSet(dir string, dbt DBType) (map[string]*file, error) {
	tmp, err := ioutil.ReadDir(dir)
	if err != nil {
		return nil, errors.Wrap(err, "read dir")
	}
	overrides := []*file{}
	for _, fi := range tmp {
		fullpath := filepath.Join(dir, fi.Name())
		if !fi.IsDir() || fi.Name() != string(dbt) {
			continue
		}

		// The empty DBType prevents recursive descent into structures
		// like ./mariadb/mariadb/mariadb/...
		overrides, err = readDir(fullpath, DBType(""))
		if err != nil {
			return nil, fmt.Errorf("read dir %s: %w",
				fi.Name(), err)
		}
	}
	overrideSet := make(map[string]*file, len(overrides))
	for _, o := range overrides {
		overrideSet[o.Info.Name()] = o
	}
	return overrideSet, nil
}

// sortFiles by name, ensuring that something like 1.sql, 2.sql, 10.sql is
// ordered correctly.
func sortfiles(files []os.FileInfo) error {
func sortFiles(files []*file) error {
	var nameErr error
	sort.Slice(files, func(i, j int) bool {
		if nameErr != nil {
			return false
		}
		fiName1 := regexNum.FindString(files[i].Name())
		fiName2 := regexNum.FindString(files[j].Name())
		fiName1 := regexNum.FindString(files[i].Info.Name())
		fiName2 := regexNum.FindString(files[j].Info.Name())
		fiNum1, err := strconv.ParseUint(fiName1, 10, 64)
		if err != nil {
			nameErr = errors.Wrapf(err, "parse uint in file %s", files[i].Name())
			nameErr = errors.Wrapf(err, "parse uint in file %s",
				files[i].Info.Name())
			return false
		}
		fiNum2, err := strconv.ParseUint(fiName2, 10, 64)
		if err != nil {
			nameErr = errors.Wrapf(err, "parse uint in file %s", files[i].Name())
			nameErr = errors.Wrapf(err, "parse uint in file %s",
				files[i].Info.Name())
			return false
		}
		if fiNum1 == fiNum2 {


@@ 362,15 448,16 @@ func sortfiles(files []os.FileInfo) error {

func migrationsFromFiles(m *Migrate) ([]Migration, error) {
	ms := make([]Migration, len(m.Files))
	for i, fileInfo := range m.Files {
		filename := filepath.Join(m.dir, fileInfo.Name())
		byt, err := ioutil.ReadFile(filename)
	for i, fi := range m.Files {
		fmt.Println("FULLPATH", fi.fullpath)
		byt, err := ioutil.ReadFile(fi.fullpath)
		if err != nil {
			return nil, errors.Wrap(err, "read file")
		}
		ms[i] = Migration{
			Filename: fileInfo.Name(),
			Filename: fi.Info.Name(),
			Content:  string(byt),
			fullpath: fi.fullpath,
		}
	}
	return ms, nil