~egtann/sum

b62379e992957abff4ea7a7a0800b0458d5b6aae — Evan Tann 4 months ago e6d880e
modify mysql priv tables directly

MySQL GRANT and REVOKE statements take effect quickly after executing
(but not immediately) without FLUSH PRIVILEGES. It's very dangerous to
use these statements as a result, since there's no guarantee they'll be
applied all at once.

To work around this, we modify the _priv tables directly in a
transaction then FLUSH PRIVILEGES manually. Note that this is not
compatible with MariaDB; although MariaDB has a user table, it's not a
real table, it's a view, and cannot be modified directly.
5 files changed, 388 insertions(+), 52 deletions(-)

M cmd/sf/main.go
M mysql/store.go
M parser.go
M sf.go
M store.go
M cmd/sf/main.go => cmd/sf/main.go +16 -12
@@ 36,7 36,7 @@ func main() {
func run() error {
	f := parseFlags()

	// TODO(egtann) migrate+pledge
	// TODO(egtann) pledge+unveil

	// Request database password if not provided as a flag argument
	if f.password == "" {


@@ 64,19 64,23 @@ func run() error {
	}
	defer fi.Close()

	grants, err := sf.BuildGrants(db, fi)
	if err != nil {
		return fmt.Errorf("build grants: %w", err)
	}
	if len(grants) == 0 {
		return errors.New("no grants")
	}
	if f.dry {
		fmt.Println(strings.Join(grants, ";\n") + ";")
	} else {
		if err = sf.ApplyGrants(db, grants); err != nil {
			return fmt.Errorf("apply grants: %w", err)
		grants, err := sf.BuildGrants(db, fi)
		if err != nil {
			return fmt.Errorf("build grants: %w", err)
		}
		if len(grants) == 0 {
			return errors.New("no grants")
		}
		fmt.Println(strings.Join(grants, ";\n") + ";")
		return nil
	}
	perms, err := sf.BuildPermissions(db, fi)
	if err != nil {
		return fmt.Errorf("build permissions: %w", err)
	}
	if err = db.Apply(perms); err != nil {
		return fmt.Errorf("apply permissions: %w", err)
	}
	return nil
}

M mysql/store.go => mysql/store.go +321 -7
@@ 1,11 1,13 @@
package mysql

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io/ioutil"
	"os"
	"strings"

	"egt.run/sf"


@@ 23,6 25,25 @@ type DB struct {
	*sqlx.DB
}

type tablePerm struct {
	DB          string
	User        string
	Host        string
	TableName   string
	Grantor     string
	TablePrivs  []string
	ColumnPrivs []string
}

type columnPerm struct {
	DB          string
	User        string
	Host        string
	TableName   string
	ColumnName  string
	ColumnPrivs []string
}

func New(
	user, pass, host string,
	port int,


@@ 122,16 143,16 @@ func (db *DB) GetSchema() (*sf.Schema, error) {
	}

	var users []struct {
		User string
		Host string
		User string `db:"User"`
		Host string `db:"Host"`
	}
	q = `SELECT user, host FROM mysql.user`
	q = `SELECT User, Host FROM mysql.user`
	if err := db.Select(&users, q); err != nil {
		return nil, fmt.Errorf("select users: %w", err)
	}
	for _, u := range users {
		// Don't modify mysql.session or mysql.sys
		if strings.HasPrefix(u.User, "mysql.") {
		if strings.HasPrefix(u.User, "mysql.") || u.User == "_mysql" {
			continue
		}
		userHost := fmt.Sprintf("'%s'@'%s'", u.User, u.Host)


@@ 140,9 161,218 @@ func (db *DB) GetSchema() (*sf.Schema, error) {
	return schema, nil
}

func (db *DB) Apply(s string) error {
	if _, err := db.Exec(s); err != nil {
		return fmt.Errorf("exec %s: %w", s, err)
// Apply permissions for all users.
func (db *DB) Apply(userPerms map[string]*sf.Permissions) (err error) {
	var tx *sqlx.Tx
	tx, err = db.BeginTxx(context.Background(), nil)
	if err != nil {
		return fmt.Errorf("begin tx: %w", err)
	}
	defer func() {
		if err == nil {
			if err2 := tx.Commit(); err2 != nil {
				fmt.Fprintln(os.Stderr, "failed to commit", err2)
			}
		} else {
			if err2 := tx.Rollback(); err2 != nil {
				fmt.Fprintln(os.Stderr, "failed to rollback", err2)
			}
		}
	}()

	// Update user table and revoke all table + column privileges
	q := `DELETE FROM mysql.db WHERE user NOT LIKE 'mysql.%'`
	if _, err = tx.Exec(q); err != nil {
		return fmt.Errorf("delete db priv: %w", err)
	}
	q = `DELETE FROM mysql.tables_priv`
	if _, err = tx.Exec(q); err != nil {
		return fmt.Errorf("delete tables_priv: %w", err)
	}
	q = `DELETE FROM mysql.columns_priv`
	if _, err = tx.Exec(q); err != nil {
		return fmt.Errorf("delete columns_priv: %w", err)
	}

	for user, p := range userPerms {
		parts := strings.SplitN(user, "@", 2)
		user = parts[0]
		var host string
		if len(parts) != 2 {
			host = "%"
		} else {
			host = parts[1]
		}

		// TODO(egtann) add tests for this complex code...
		if err = applyUser(tx, user, host, p); err != nil {
			return fmt.Errorf("apply user %s@%s: %w", user, host, err)
		}
	}

	// And now ensure we're applying those permissions
	// TODO(egtann) should this go in defer after commit?
	q = `FLUSH PRIVILEGES`
	if _, err = tx.Exec(q); err != nil {
		return fmt.Errorf("flush privileges: %w", err)
	}
	return nil
}

// applyUser permissions by directly editing priv tables so we can manually
// trigger flushing privileges.
func applyUser(tx *sqlx.Tx, user, host string, p *sf.Permissions) error {
	var pq *sqlUserPerms
	if p.Deny {
		pq = newSQLUserPerms("N")
	} else {
		// This is the behavior from:
		//	GRANT ALL ON *.* TO user WITH GRANT OPTION;
		pq = newSQLUserPerms("Y")
		pq.File = "N"
		pq.Super = "N"
	}
	q := `UPDATE mysql.user
	      SET Select_priv=?, Insert_priv=?, Update_priv=?, Delete_priv=?,
	          Create_priv=?, Drop_priv=?, Reload_priv=?, Shutdown_priv=?,
	          Process_priv=?, File_priv=?, Grant_priv=?, References_priv=?,
	          Index_priv=?, Alter_priv=?, Show_db_priv=?, Super_priv=?,
	          Create_tmp_table_priv=?, Lock_tables_priv=?, Execute_priv=?,
	          Repl_slave_priv=?, Repl_client_priv=?, Create_view_priv=?,
	          Show_view_priv=?, Create_routine_priv=?, Event_priv=?,
	          Trigger_priv=?, Create_tablespace_priv=?
	       WHERE User=? AND Host=?`
	_, err := tx.Exec(q, pq.Select, pq.Insert, pq.Update, pq.Delete,
		pq.Create, pq.Drop, pq.Reload, pq.Shutdown,
		pq.Process, pq.File, pq.Grant, pq.References,
		pq.Index, pq.Alter, pq.ShowDB, pq.Super,
		pq.CreateTmpTable, pq.LockTables, pq.Execute,
		pq.ReplSlave, pq.ReplClient, pq.CreateView,
		pq.ShowView, pq.CreateRoutine, pq.Event,
		pq.Trigger, pq.CreateTablespace,
		user, host,
	)
	if err != nil {
		return fmt.Errorf("set user privs: %w", err)
	}
	if !p.Deny {
		return nil
	}

	// Update tables_priv table since we're only allowing partial access to
	// specific tables or columns
	tables := map[string]*tablePerm{}
	columns := map[string]*columnPerm{}
	for d, db := range p.Databases {
		for t, table := range db.Tables {
			for s, statement := range table.Statements {
				if statement.Deny {
					continue
				}
				s = strings.Title(s)
				if s == "Show View" {
					s = "Show view" // ugh :/
				}
				if allAllowed(statement.Columns) {
					// Add a table entry, no column entry
					var tp *tablePerm
					key := tableKey(d, user, host, t)
					if _, ok := tables[key]; !ok {
						tables[key] = &tablePerm{
							DB:        d,
							User:      user,
							Host:      host,
							TableName: t,
						}
					}
					tp = tables[key]
					tp.TablePrivs = append(tp.TablePrivs, s)
					tables[key] = tp
					continue
				}
				switch s {
				case "Select", "Insert", "Update", "References":
					// Keep going
				default:
					// Other statements cannot apply to the
					// column-level
					continue
				}

				// Add a table entry indicating that this is a
				// column-level permission
				var tp *tablePerm
				key := tableKey(d, user, host, t)
				if _, ok := tables[key]; !ok {
					tables[key] = &tablePerm{
						DB:        d,
						User:      user,
						Host:      host,
						TableName: t,
					}
				}
				tp = tables[key]
				tp.ColumnPrivs = append(tp.ColumnPrivs, s)
				tables[key] = tp
				for c, deny := range statement.Columns {
					if deny {
						continue
					}
					c = strings.Title(c)
					var cp *columnPerm
					key = columnKey(d, user, host, t, c)
					if _, ok := columns[key]; !ok {
						columns[key] = &columnPerm{
							DB:         d,
							User:       user,
							Host:       host,
							TableName:  t,
							ColumnName: c,
						}
					}
					cp = columns[key]
					cp.ColumnPrivs = append(cp.ColumnPrivs, s)
					columns[key] = cp
				}
			}
		}
	}

	for _, t := range tables {
		var (
			tablePrivs  string
			columnPrivs string
		)
		if len(t.TablePrivs) > 0 {
			tablePrivs = strings.Join(t.TablePrivs, ",")
		}
		if len(t.ColumnPrivs) > 0 {
			columnPrivs = strings.Join(t.ColumnPrivs, ",")
		}
		if len(t.TablePrivs) == 0 && len(t.ColumnPrivs) == 0 {
			// This shouldn't happen
			return errors.New("empty privs")
		}
		q = `INSERT INTO mysql.tables_priv (
			 Db, User, Host, Table_name, Grantor, Table_priv,
			 Column_priv
		     ) VALUES (?, ?, ?, ?, (SELECT USER()), ?, ?)`
		_, err = tx.Exec(q, t.DB, t.User, t.Host, t.TableName,
			tablePrivs, columnPrivs)
		if err != nil {
			return fmt.Errorf("insert tables_priv: %w", err)
		}
	}
	for _, c := range columns {
		privs := strings.Join(c.ColumnPrivs, ",")
		q = `INSERT INTO mysql.columns_priv (
			 Db, User, Host, Table_name, Column_name, Column_priv
		     ) VALUES (?, ?, ?, ?, ?, ?)`
		_, err = tx.Exec(q, c.DB, c.User, c.Host, c.TableName,
			c.ColumnName, privs)
		if err != nil {
			return fmt.Errorf("insert columns_priv: %w", err)
		}
	}
	return nil
}


@@ 178,3 408,87 @@ func newTLSConfig(
	}
	return conf, nil
}

type sqlUserPerms = struct {
	Select           string
	Insert           string
	Update           string
	Delete           string
	Create           string
	Drop             string
	Reload           string
	Shutdown         string
	Process          string
	File             string
	Grant            string
	References       string
	Index            string
	Alter            string
	ShowDB           string
	Super            string
	CreateTmpTable   string
	LockTables       string
	Execute          string
	ReplSlave        string
	ReplClient       string
	CreateView       string
	ShowView         string
	CreateRoutine    string
	Event            string
	Trigger          string
	CreateTablespace string
}

// newSQLUserPerms initializes permissions all set to the provided value.
func newSQLUserPerms(val string) *sqlUserPerms {
	return &sqlUserPerms{
		Select:           val,
		Insert:           val,
		Update:           val,
		Delete:           val,
		Create:           val,
		Drop:             val,
		Reload:           val,
		Shutdown:         val,
		Process:          val,
		File:             val,
		Grant:            val,
		References:       val,
		Index:            val,
		Alter:            val,
		ShowDB:           val,
		Super:            val,
		CreateTmpTable:   val,
		LockTables:       val,
		Execute:          val,
		ReplSlave:        val,
		ReplClient:       val,
		CreateView:       val,
		ShowView:         val,
		CreateRoutine:    val,
		Event:            val,
		Trigger:          val,
		CreateTablespace: val,
	}
}

func allAllowed(columns map[string]bool) bool {
	for _, deny := range columns {
		if deny {
			return false
		}
	}
	return true
}

// Key is db.user.host.table_name, which is the primary key for the
// tables_priv table.
func tableKey(db, user, host, table string) string {
	return fmt.Sprintf("%s.%s.%s.%s", db, user, host, table)
}

// Key is db.user.host.table_name.column_name, which is the primary key for the
// columns_priv table.
func columnKey(db, user, host, table, column string) string {
	return fmt.Sprintf("%s.%s.%s.%s.%s", db, user, host, table, column)
}

M parser.go => parser.go +26 -24
@@ 31,24 31,24 @@ type seenLines struct {
	columns    map[string]struct{}
}

type permissions struct {
type Permissions struct {
	Deny      bool                     `json:",omitempty"`
	Databases map[string]*dbPermission `json:",omitempty"`
	Databases map[string]*DBPermission `json:",omitempty"`
}

type dbPermission struct {
type DBPermission struct {
	Deny   bool                        `json:",omitempty"`
	Tables map[string]*tablePermission `json:",omitempty"`
	Tables map[string]*TablePermission `json:",omitempty"`
}

type statementPermission struct {
type StatementPermission struct {
	Deny    bool            `json:",omitempty"`
	Columns map[string]bool `json:",omitempty"`
}

type tablePermission struct {
type TablePermission struct {
	Deny       bool                            `json:",omitempty"`
	Statements map[string]*statementPermission `json:",omitempty"`
	Statements map[string]*StatementPermission `json:",omitempty"`
}

// compile permissions for every user in statements.


@@ 56,7 56,7 @@ func compile(
	schema *Schema,
	allStatements []string,
	a *ast,
) (map[string]*permissions, error) {
) (map[string]*Permissions, error) {
	// First collect a list of all users
	//
	// TODO(egtann) this can be done during parsing...


@@ 100,8 100,9 @@ func compile(
	}

	// Given each user's set of lines, assemble permissions for them.
	userPerms := map[string]*permissions{}
	userPerms := map[string]*Permissions{}
	for u, lines := range userLines {
		fmt.Println("USER LINES", u)
		var err error
		userPerms[u], err = permsForLines(schema, allStatements,
			a.vars, lines)


@@ 115,7 116,7 @@ func compile(
// dbsAllGranted searches through all sub-permissions to determine if any are
// denied. It reports true iff all permissions are granted across every
// sub-resource.
func dbsAllGranted(dbs map[string]*dbPermission) bool {
func dbsAllGranted(dbs map[string]*DBPermission) bool {
	for _, db := range dbs {
		if db.Deny {
			return false


@@ 130,7 131,7 @@ func dbsAllGranted(dbs map[string]*dbPermission) bool {
// tablesAllGranted searches through all sub-permissions to determine if any
// are denied. It reports true iff all permissions are granted across every
// sub-resource.
func tablesAllGranted(tables map[string]*tablePermission) bool {
func tablesAllGranted(tables map[string]*TablePermission) bool {
	for _, s := range tables {
		if s.Deny {
			return false


@@ 151,7 152,7 @@ func tablesAllGranted(tables map[string]*tablePermission) bool {

// grants outputs the minimum set of GRANT and REVOKE statements that will
// enforce the permissions.
func (p *permissions) grants(user string) []string {
func (p *Permissions) grants(user string) []string {
	// Default to any host, but allow overriding it
	if !strings.Contains(user, "@") {
		user += "@'%'"


@@ 208,7 209,7 @@ func (p *permissions) grants(user string) []string {
}

// apply permissions for a given line.
func (p *permissions) apply(vars map[string][]string, l *line) error {
func (p *Permissions) apply(vars map[string][]string, l *line) error {
	deny := l.verb == "deny"

	// Make the appropriate variable substitutions


@@ 339,7 340,7 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
	return nil
}

func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
func markAnyDBSeen(sl *seenLines, p *Permissions) *seenLines {
	for d, db := range p.Databases {
		sl.databases[d] = struct{}{}
		for t, table := range db.Tables {


@@ 355,7 356,7 @@ func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
	return sl
}

func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
func markAnyTableSeen(sl *seenLines, db *DBPermission) *seenLines {
	for t, table := range db.Tables {
		sl.tables[t] = struct{}{}
		for s, statement := range table.Statements {


@@ 368,7 369,7 @@ func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
	return sl
}

func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {
func markAnyStatementSeen(sl *seenLines, table *TablePermission) *seenLines {
	for s, statement := range table.Statements {
		sl.statements[s] = struct{}{}
		for c := range statement.Columns {


@@ 380,18 381,18 @@ func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {

// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
func permsForSchema(schema *Schema, allStatements []string) *permissions {
	perms := &permissions{Databases: map[string]*dbPermission{}}
func permsForSchema(schema *Schema, allStatements []string) *Permissions {
	perms := &Permissions{Databases: map[string]*DBPermission{}}
	for dbName, tables := range schema.Databases {
		perms.Databases[dbName] = &dbPermission{
			Tables: map[string]*tablePermission{},
		perms.Databases[dbName] = &DBPermission{
			Tables: map[string]*TablePermission{},
		}
		for tableName, columns := range tables {
			perms.Databases[dbName].Tables[tableName] = &tablePermission{
				Statements: map[string]*statementPermission{},
			perms.Databases[dbName].Tables[tableName] = &TablePermission{
				Statements: map[string]*StatementPermission{},
			}
			for _, statement := range allStatements {
				perms.Databases[dbName].Tables[tableName].Statements[statement] = &statementPermission{
				perms.Databases[dbName].Tables[tableName].Statements[statement] = &StatementPermission{
					Columns: map[string]bool{},
				}
				for colName := range columns {


@@ 408,9 409,10 @@ func permsForLines(
	allStatements []string,
	vars map[string][]string,
	ls []*line,
) (*permissions, error) {
) (*Permissions, error) {
	perms := permsForSchema(schema, allStatements)
	for _, l := range ls {
		fmt.Println("APPLY", l)
		if err := perms.apply(vars, l); err != nil {
			return nil, fmt.Errorf("apply line %d: %w", l.line, err)
		}

M sf.go => sf.go +23 -7
@@ 5,6 5,11 @@ import (
	"io"
)

// BuildGrants which can be applied in a database. Note that grant and revoke
// statements take effect immediately, so you may lock yourself out by applying
// this. Instead use this as a quick way to audit permissions generated by sf,
// but use sf to apply those permissions directly to the database by editing
// priv tables directly.
func BuildGrants(db Store, r io.Reader) ([]string, error) {
	schema, err := db.GetSchema()
	if err != nil {


@@ 27,15 32,26 @@ func BuildGrants(db Store, r io.Reader) ([]string, error) {
	for u, perm := range userPerms {
		allGrants = append(allGrants, perm.grants(u)...)
	}
	allGrants = append(allGrants, "FLUSH PRIVILEGES")
	return allGrants, nil
}

func ApplyGrants(db Store, grants []string) error {
	for _, g := range grants {
		if err := db.Apply(g); err != nil {
			return fmt.Errorf("apply: %w", err)
		}
// BuildPermissions for each user.
func BuildPermissions(
	db Store,
	r io.Reader,
) (map[string]*Permissions, error) {
	schema, err := db.GetSchema()
	if err != nil {
		return nil, fmt.Errorf("get schema: %w", err)
	}
	allStatements := db.Statements()
	ast, err := parse(r)
	if err != nil {
		return nil, fmt.Errorf("parse: %w", err)
	}
	userPerms, err := compile(schema, allStatements, ast)
	if err != nil {
		return nil, fmt.Errorf("compile: %w", err)
	}
	return nil
	return userPerms, nil
}

M store.go => store.go +2 -2
@@ 20,6 20,6 @@ type Store interface {
	// GetSchema for all databases.
	GetSchema() (*Schema, error)

	// Apply GRANT, REVOKE, or FLUSH PRIVILEGES statements.
	Apply(string) error
	// Apply permissions in the DB for every user.
	Apply(map[string]*Permissions) error
}