~egtann/sum

69af265c92763de7d73a49e82e77487769205337 — Evan Tann 4 months ago 7949e02
fix mysql tls connection bug, all users bug
9 files changed, 92 insertions(+), 54 deletions(-)

M README.md
M cmd/sf/main.go
M go.mod
M go.sum
M mysql/store.go
M parser.go
M parser_test.go
M sf.go
M store.go
M README.md => README.md +2 -0
@@ 111,5 111,7 @@ every database migration.

## TODO

* validate that every defined user, db, table, statement, column at least
  exists in the db, since otherwise the failure is silent
* quoted values, e.g. multi-word statement "alter routine"
* combine users into the `permissions` struct to simplify code?

M cmd/sf/main.go => cmd/sf/main.go +1 -1
@@ 95,7 95,7 @@ func parseFlags() *flags {
	flag.StringVar(&f.sslKey, "ssl-key", "", "pem file containing the ssl client key")
	flag.StringVar(&f.sslCert, "ssl-cert", "", "pem file containing the ssl client certificate")
	flag.StringVar(&f.sslCA, "ssl-ca", "", "pem file containing the ssl server ca")
	flag.StringVar(&f.sslServerName, "ssl-server-name", "", "ssl server name")
	flag.StringVar(&f.sslServerName, "ssl-server", "", "ssl server name")
	flag.Parse()
	return f
}

M go.mod => go.mod +2 -1
@@ 3,8 3,9 @@ module egt.run/sf
go 1.14

require (
	github.com/go-sql-driver/mysql v1.4.0
	github.com/go-sql-driver/mysql v1.5.0
	github.com/jmoiron/sqlx v1.2.0
	golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
	golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae // indirect
	google.golang.org/appengine v1.6.6 // indirect
)

M go.sum => go.sum +4 -0
@@ 1,5 1,7 @@
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=


@@ 15,6 17,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae h1:Ih9Yo4hSPImZOpfGuA4bR/ORKTAbhZo2AbWNRCnevdo=
golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

M mysql/store.go => mysql/store.go +31 -10
@@ 6,6 6,7 @@ import (
	"errors"
	"fmt"
	"io/ioutil"
	"strings"

	"egt.run/sf"
	"github.com/go-sql-driver/mysql"


@@ 30,7 31,7 @@ func New(
	db := &DB{}
	db.connURL = fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, pass, host, port)
	if sslKey != "" {
		db.connURL = fmt.Sprintf("%s&tls=%s", db.connURL, sslServerName)
		db.connURL = fmt.Sprintf("%s?tls=%s", db.connURL, sslServerName)
		var err error
		db.tlsConfig, err = newTLSConfig(sslKey, sslCert, sslCA,
			sslServerName)


@@ 54,6 55,9 @@ func (db *DB) Open() error {
	if err != nil {
		return fmt.Errorf("open db connection: %w", err)
	}
	if err = db.DB.Ping(); err != nil {
		return fmt.Errorf("ping: %w", err)
	}
	return nil
}



@@ 95,7 99,7 @@ func (db *DB) Statements() []string {
	}
}

func (db *DB) GetSchema() (sf.Schema, error) {
func (db *DB) GetSchema() (*sf.Schema, error) {
	var data []struct {
		Database string `db:"table_schema"`
		Table    string `db:"table_name"`


@@ 104,19 108,36 @@ func (db *DB) GetSchema() (sf.Schema, error) {
	q := `SELECT table_schema, table_name, column_name
	      FROM information_schema.columns`
	if err := db.Select(&data, q); err != nil {
		return nil, fmt.Errorf("select: %w", err)
		return nil, fmt.Errorf("select columns: %w", err)
	}
	schema := map[string]sf.Database{}
	schema := &sf.Schema{Databases: map[string]sf.Database{}}
	for _, d := range data {
		if _, ok := schema[d.Database]; !ok {
			schema[d.Database] = sf.Database{}
		if _, ok := schema.Databases[d.Database]; !ok {
			schema.Databases[d.Database] = sf.Database{}
		}
		if _, ok := schema.Databases[d.Database][d.Table]; !ok {
			schema.Databases[d.Database][d.Table] = sf.Table{}
		}
		if _, ok := schema[d.Database][d.Table]; !ok {
			schema[d.Database][d.Table] = sf.Table{}
		schema.Databases[d.Database][d.Table][d.Column] = struct{}{}
	}

	var users []struct {
		User string
		Host string
	}
	q = `SELECT user, host FROM mysql.user`
	if err := db.Select(&users, q); err != nil {
		return nil, fmt.Errorf("select users: %w", err)
	}
	for _, u := range users {
		// Don't modify mysql.session or mysql.sys
		if strings.HasPrefix(u.User, "mysql.") {
			continue
		}
		schema[d.Database][d.Table][d.Column] = struct{}{}
		userHost := fmt.Sprintf("'%s'@'%s'", u.User, u.Host)
		schema.Users = append(schema.Users, userHost)
	}
	return sf.Schema(schema), nil
	return schema, nil
}

func (db *DB) Apply(s string) error {

M parser.go => parser.go +12 -15
@@ 45,7 45,7 @@ type tablePermission struct {

// compile permissions for every user in statements.
func compile(
	schema Schema,
	schema *Schema,
	allStatements []string,
	a *ast,
) (map[string]*permissions, error) {


@@ 72,22 72,18 @@ func compile(
	// "all" users as needed.
	userLines := map[string][]*line{}
	for _, line := range a.lines {
		if len(line.users) == 1 && line.users[0] == "all" {
			line.users = allUsers
		}
		for _, u := range line.users {
			if strings.HasPrefix(u, "$") {
				users := a.vars[strings.TrimPrefix(u, "$")]
				for _, u2 := range users {
					userLines[u2] = append(userLines[u2], line)
				}
				continue
			}
			if u != "all" {
			if !strings.HasPrefix(u, "$") {
				userLines[u] = append(userLines[u], line)
				continue
			}
			for _, u2 := range allUsers {
			users := a.vars[strings.TrimPrefix(u, "$")]
			for _, u2 := range users {
				userLines[u2] = append(userLines[u2], line)
			}
			continue
		}
	}



@@ 95,7 91,8 @@ func compile(
	userPerms := map[string]*permissions{}
	for u, lines := range userLines {
		var err error
		userPerms[u], err = permsForLines(schema, allStatements, a.vars, lines)
		userPerms[u], err = permsForLines(schema, allStatements,
			a.vars, lines)
		if err != nil {
			return nil, fmt.Errorf("perms for lines: %w", err)
		}


@@ 261,9 258,9 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {

// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
func permsForSchema(schema Schema, allStatements []string) *permissions {
func permsForSchema(schema *Schema, allStatements []string) *permissions {
	perms := &permissions{Databases: map[string]*dbPermission{}}
	for dbName, tables := range schema {
	for dbName, tables := range schema.Databases {
		perms.Databases[dbName] = &dbPermission{
			Tables: map[string]*tablePermission{},
		}


@@ 285,7 282,7 @@ func permsForSchema(schema Schema, allStatements []string) *permissions {
}

func permsForLines(
	schema Schema,
	schema *Schema,
	allStatements []string,
	vars map[string][]string,
	ls []*line,

M parser_test.go => parser_test.go +30 -24
@@ 173,10 173,12 @@ func TestPermissionsApply(t *testing.T) {
		t.Run(fmt.Sprint(i), func(t *testing.T) {
			t.Parallel()

			schema := Schema{
				"db": Database{
					"table": Table{
						"column": struct{}{},
			schema := &Schema{
				Databases: map[string]Database{
					"db": Database{
						"table": Table{
							"column": struct{}{},
						},
					},
				},
			}


@@ 267,10 269,12 @@ func TestPermsForLines(t *testing.T) {
		tc := tc
		t.Run(fmt.Sprint(i), func(t *testing.T) {
			t.Parallel()
			schema := Schema{
				"db": Database{
					"table": Table{
						"column": struct{}{},
			schema := &Schema{
				Databases: map[string]Database{
					"db": Database{
						"table": Table{
							"column": struct{}{},
						},
					},
				},
			}


@@ 409,23 413,25 @@ func TestParse(t *testing.T) {
	if err != nil {
		t.Fatal(err)
	}
	schema := Schema{
		"dashboard": Database{
			"admins": Table{
				"id":       struct{}{},
				"email":    struct{}{},
				"password": struct{}{},
			},
			"users": Table{
				"id":       struct{}{},
				"email":    struct{}{},
				"password": struct{}{},
	schema := &Schema{
		Databases: map[string]Database{
			"dashboard": Database{
				"admins": Table{
					"id":       struct{}{},
					"email":    struct{}{},
					"password": struct{}{},
				},
				"users": Table{
					"id":       struct{}{},
					"email":    struct{}{},
					"password": struct{}{},
				},
			},
		},
		"nogrant": Database{
			"messages": Table{
				"id":      struct{}{},
				"content": struct{}{},
			"nogrant": Database{
				"messages": Table{
					"id":      struct{}{},
					"content": struct{}{},
				},
			},
		},
	}

M sf.go => sf.go +5 -1
@@ 19,7 19,11 @@ func BuildGrants(db Store, r io.Reader) ([]string, error) {
	if err != nil {
		return nil, fmt.Errorf("compile: %w", err)
	}
	var allGrants []string
	allGrants := make([]string, 0, len(schema.Users))
	for _, u := range schema.Users {
		g := fmt.Sprintf("REVOKE ALL PRIVILEGES, GRANT OPTION FROM %s", u)
		allGrants = append(allGrants, g)
	}
	for u, perm := range userPerms {
		allGrants = append(allGrants, perm.grants(u)...)
	}

M store.go => store.go +5 -2
@@ 1,7 1,10 @@
package sf

// Schema maps names to databases.
type Schema map[string]Database
type Schema struct {
	Users     []string
	Databases map[string]Database
}

// Database maps names to tables.
type Database map[string]Table


@@ 15,7 18,7 @@ type Store interface {
	Statements() []string

	// GetSchema for all databases.
	GetSchema() (Schema, error)
	GetSchema() (*Schema, error)

	// Apply GRANT, REVOKE, or FLUSH PRIVILEGES statements.
	Apply(string) error