M README.md => README.md +2 -0
@@ 111,5 111,7 @@ every database migration.
## TODO
+* validate that every defined user, db, table, statement, column at least
+ exists in the db, since otherwise the failure is silent
* quoted values, e.g. multi-word statement "alter routine"
* combine users into the `permissions` struct to simplify code?
M cmd/sf/main.go => cmd/sf/main.go +1 -1
@@ 95,7 95,7 @@ func parseFlags() *flags {
flag.StringVar(&f.sslKey, "ssl-key", "", "pem file containing the ssl client key")
flag.StringVar(&f.sslCert, "ssl-cert", "", "pem file containing the ssl client certificate")
flag.StringVar(&f.sslCA, "ssl-ca", "", "pem file containing the ssl server ca")
- flag.StringVar(&f.sslServerName, "ssl-server-name", "", "ssl server name")
+ flag.StringVar(&f.sslServerName, "ssl-server", "", "ssl server name")
flag.Parse()
return f
}
M go.mod => go.mod +2 -1
@@ 3,8 3,9 @@ module egt.run/sf
go 1.14
require (
- github.com/go-sql-driver/mysql v1.4.0
+ github.com/go-sql-driver/mysql v1.5.0
github.com/jmoiron/sqlx v1.2.0
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9
+ golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae // indirect
google.golang.org/appengine v1.6.6 // indirect
)
M go.sum => go.sum +4 -0
@@ 1,5 1,7 @@
github.com/go-sql-driver/mysql v1.4.0 h1:7LxgVwFb2hIQtMm87NdgAVfXjnt4OePseqT1tKx+opk=
github.com/go-sql-driver/mysql v1.4.0/go.mod h1:zAC/RDZ24gD3HViQzih4MyKcchzm+sOG5ZlKdlhCg5w=
+github.com/go-sql-driver/mysql v1.5.0 h1:ozyZYNQW3x3HtqT1jira07DN2PArx2v7/mN66gGcHOs=
+github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/jmoiron/sqlx v1.2.0 h1:41Ip0zITnmWNR/vHV+S4m+VoUivnWY5E4OJfLZjCJMA=
github.com/jmoiron/sqlx v1.2.0/go.mod h1:1FEQNm3xlJgrMD+FBdI9+xvCksHtbpVBBw5dYhBSsks=
@@ 15,6 17,8 @@ golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
+golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae h1:Ih9Yo4hSPImZOpfGuA4bR/ORKTAbhZo2AbWNRCnevdo=
+golang.org/x/sys v0.0.0-20200625212154-ddb9806d33ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
M mysql/store.go => mysql/store.go +31 -10
@@ 6,6 6,7 @@ import (
"errors"
"fmt"
"io/ioutil"
+ "strings"
"egt.run/sf"
"github.com/go-sql-driver/mysql"
@@ 30,7 31,7 @@ func New(
db := &DB{}
db.connURL = fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, pass, host, port)
if sslKey != "" {
- db.connURL = fmt.Sprintf("%s&tls=%s", db.connURL, sslServerName)
+ db.connURL = fmt.Sprintf("%s?tls=%s", db.connURL, sslServerName)
var err error
db.tlsConfig, err = newTLSConfig(sslKey, sslCert, sslCA,
sslServerName)
@@ 54,6 55,9 @@ func (db *DB) Open() error {
if err != nil {
return fmt.Errorf("open db connection: %w", err)
}
+ if err = db.DB.Ping(); err != nil {
+ return fmt.Errorf("ping: %w", err)
+ }
return nil
}
@@ 95,7 99,7 @@ func (db *DB) Statements() []string {
}
}
-func (db *DB) GetSchema() (sf.Schema, error) {
+func (db *DB) GetSchema() (*sf.Schema, error) {
var data []struct {
Database string `db:"table_schema"`
Table string `db:"table_name"`
@@ 104,19 108,36 @@ func (db *DB) GetSchema() (sf.Schema, error) {
q := `SELECT table_schema, table_name, column_name
FROM information_schema.columns`
if err := db.Select(&data, q); err != nil {
- return nil, fmt.Errorf("select: %w", err)
+ return nil, fmt.Errorf("select columns: %w", err)
}
- schema := map[string]sf.Database{}
+ schema := &sf.Schema{Databases: map[string]sf.Database{}}
for _, d := range data {
- if _, ok := schema[d.Database]; !ok {
- schema[d.Database] = sf.Database{}
+ if _, ok := schema.Databases[d.Database]; !ok {
+ schema.Databases[d.Database] = sf.Database{}
+ }
+ if _, ok := schema.Databases[d.Database][d.Table]; !ok {
+ schema.Databases[d.Database][d.Table] = sf.Table{}
}
- if _, ok := schema[d.Database][d.Table]; !ok {
- schema[d.Database][d.Table] = sf.Table{}
+ schema.Databases[d.Database][d.Table][d.Column] = struct{}{}
+ }
+
+ var users []struct {
+ User string
+ Host string
+ }
+ q = `SELECT user, host FROM mysql.user`
+ if err := db.Select(&users, q); err != nil {
+ return nil, fmt.Errorf("select users: %w", err)
+ }
+ for _, u := range users {
+ // Don't modify mysql.session or mysql.sys
+ if strings.HasPrefix(u.User, "mysql.") {
+ continue
}
- schema[d.Database][d.Table][d.Column] = struct{}{}
+ userHost := fmt.Sprintf("'%s'@'%s'", u.User, u.Host)
+ schema.Users = append(schema.Users, userHost)
}
- return sf.Schema(schema), nil
+ return schema, nil
}
func (db *DB) Apply(s string) error {
M parser.go => parser.go +12 -15
@@ 45,7 45,7 @@ type tablePermission struct {
// compile permissions for every user in statements.
func compile(
- schema Schema,
+ schema *Schema,
allStatements []string,
a *ast,
) (map[string]*permissions, error) {
@@ 72,22 72,18 @@ func compile(
// "all" users as needed.
userLines := map[string][]*line{}
for _, line := range a.lines {
+ if len(line.users) == 1 && line.users[0] == "all" {
+ line.users = allUsers
+ }
for _, u := range line.users {
- if strings.HasPrefix(u, "$") {
- users := a.vars[strings.TrimPrefix(u, "$")]
- for _, u2 := range users {
- userLines[u2] = append(userLines[u2], line)
- }
- continue
- }
- if u != "all" {
+ if !strings.HasPrefix(u, "$") {
userLines[u] = append(userLines[u], line)
continue
}
- for _, u2 := range allUsers {
+ users := a.vars[strings.TrimPrefix(u, "$")]
+ for _, u2 := range users {
userLines[u2] = append(userLines[u2], line)
}
- continue
}
}
@@ 95,7 91,8 @@ func compile(
userPerms := map[string]*permissions{}
for u, lines := range userLines {
var err error
- userPerms[u], err = permsForLines(schema, allStatements, a.vars, lines)
+ userPerms[u], err = permsForLines(schema, allStatements,
+ a.vars, lines)
if err != nil {
return nil, fmt.Errorf("perms for lines: %w", err)
}
@@ 261,9 258,9 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
-func permsForSchema(schema Schema, allStatements []string) *permissions {
+func permsForSchema(schema *Schema, allStatements []string) *permissions {
perms := &permissions{Databases: map[string]*dbPermission{}}
- for dbName, tables := range schema {
+ for dbName, tables := range schema.Databases {
perms.Databases[dbName] = &dbPermission{
Tables: map[string]*tablePermission{},
}
@@ 285,7 282,7 @@ func permsForSchema(schema Schema, allStatements []string) *permissions {
}
func permsForLines(
- schema Schema,
+ schema *Schema,
allStatements []string,
vars map[string][]string,
ls []*line,
M parser_test.go => parser_test.go +30 -24
@@ 173,10 173,12 @@ func TestPermissionsApply(t *testing.T) {
t.Run(fmt.Sprint(i), func(t *testing.T) {
t.Parallel()
- schema := Schema{
- "db": Database{
- "table": Table{
- "column": struct{}{},
+ schema := &Schema{
+ Databases: map[string]Database{
+ "db": Database{
+ "table": Table{
+ "column": struct{}{},
+ },
},
},
}
@@ 267,10 269,12 @@ func TestPermsForLines(t *testing.T) {
tc := tc
t.Run(fmt.Sprint(i), func(t *testing.T) {
t.Parallel()
- schema := Schema{
- "db": Database{
- "table": Table{
- "column": struct{}{},
+ schema := &Schema{
+ Databases: map[string]Database{
+ "db": Database{
+ "table": Table{
+ "column": struct{}{},
+ },
},
},
}
@@ 409,23 413,25 @@ func TestParse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- schema := Schema{
- "dashboard": Database{
- "admins": Table{
- "id": struct{}{},
- "email": struct{}{},
- "password": struct{}{},
- },
- "users": Table{
- "id": struct{}{},
- "email": struct{}{},
- "password": struct{}{},
+ schema := &Schema{
+ Databases: map[string]Database{
+ "dashboard": Database{
+ "admins": Table{
+ "id": struct{}{},
+ "email": struct{}{},
+ "password": struct{}{},
+ },
+ "users": Table{
+ "id": struct{}{},
+ "email": struct{}{},
+ "password": struct{}{},
+ },
},
- },
- "nogrant": Database{
- "messages": Table{
- "id": struct{}{},
- "content": struct{}{},
+ "nogrant": Database{
+ "messages": Table{
+ "id": struct{}{},
+ "content": struct{}{},
+ },
},
},
}
M sf.go => sf.go +5 -1
@@ 19,7 19,11 @@ func BuildGrants(db Store, r io.Reader) ([]string, error) {
if err != nil {
return nil, fmt.Errorf("compile: %w", err)
}
- var allGrants []string
+ allGrants := make([]string, 0, len(schema.Users))
+ for _, u := range schema.Users {
+ g := fmt.Sprintf("REVOKE ALL PRIVILEGES, GRANT OPTION FROM %s", u)
+ allGrants = append(allGrants, g)
+ }
for u, perm := range userPerms {
allGrants = append(allGrants, perm.grants(u)...)
}
M store.go => store.go +5 -2
@@ 1,7 1,10 @@
package sf
// Schema maps names to databases.
-type Schema map[string]Database
+type Schema struct {
+ Users []string
+ Databases map[string]Database
+}
// Database maps names to tables.
type Database map[string]Table
@@ 15,7 18,7 @@ type Store interface {
Statements() []string
// GetSchema for all databases.
- GetSchema() (Schema, error)
+ GetSchema() (*Schema, error)
// Apply GRANT, REVOKE, or FLUSH PRIVILEGES statements.
Apply(string) error