@@ 1,11 1,13 @@
package mysql
import (
+ "context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io/ioutil"
+ "os"
"strings"
"egt.run/sf"
@@ 23,6 25,25 @@ type DB struct {
*sqlx.DB
}
+type tablePerm struct {
+ DB string
+ User string
+ Host string
+ TableName string
+ Grantor string
+ TablePrivs []string
+ ColumnPrivs []string
+}
+
+type columnPerm struct {
+ DB string
+ User string
+ Host string
+ TableName string
+ ColumnName string
+ ColumnPrivs []string
+}
+
func New(
user, pass, host string,
port int,
@@ 122,16 143,16 @@ func (db *DB) GetSchema() (*sf.Schema, error) {
}
var users []struct {
- User string
- Host string
+ User string `db:"User"`
+ Host string `db:"Host"`
}
- q = `SELECT user, host FROM mysql.user`
+ q = `SELECT User, Host FROM mysql.user`
if err := db.Select(&users, q); err != nil {
return nil, fmt.Errorf("select users: %w", err)
}
for _, u := range users {
// Don't modify mysql.session or mysql.sys
- if strings.HasPrefix(u.User, "mysql.") {
+ if strings.HasPrefix(u.User, "mysql.") || u.User == "_mysql" {
continue
}
userHost := fmt.Sprintf("'%s'@'%s'", u.User, u.Host)
@@ 140,9 161,218 @@ func (db *DB) GetSchema() (*sf.Schema, error) {
return schema, nil
}
-func (db *DB) Apply(s string) error {
- if _, err := db.Exec(s); err != nil {
- return fmt.Errorf("exec %s: %w", s, err)
+// Apply permissions for all users.
+func (db *DB) Apply(userPerms map[string]*sf.Permissions) (err error) {
+ var tx *sqlx.Tx
+ tx, err = db.BeginTxx(context.Background(), nil)
+ if err != nil {
+ return fmt.Errorf("begin tx: %w", err)
+ }
+ defer func() {
+ if err == nil {
+ if err2 := tx.Commit(); err2 != nil {
+ fmt.Fprintln(os.Stderr, "failed to commit", err2)
+ }
+ } else {
+ if err2 := tx.Rollback(); err2 != nil {
+ fmt.Fprintln(os.Stderr, "failed to rollback", err2)
+ }
+ }
+ }()
+
+ // Update user table and revoke all table + column privileges
+ q := `DELETE FROM mysql.db WHERE user NOT LIKE 'mysql.%'`
+ if _, err = tx.Exec(q); err != nil {
+ return fmt.Errorf("delete db priv: %w", err)
+ }
+ q = `DELETE FROM mysql.tables_priv`
+ if _, err = tx.Exec(q); err != nil {
+ return fmt.Errorf("delete tables_priv: %w", err)
+ }
+ q = `DELETE FROM mysql.columns_priv`
+ if _, err = tx.Exec(q); err != nil {
+ return fmt.Errorf("delete columns_priv: %w", err)
+ }
+
+ for user, p := range userPerms {
+ parts := strings.SplitN(user, "@", 2)
+ user = parts[0]
+ var host string
+ if len(parts) != 2 {
+ host = "%"
+ } else {
+ host = parts[1]
+ }
+
+ // TODO(egtann) add tests for this complex code...
+ if err = applyUser(tx, user, host, p); err != nil {
+ return fmt.Errorf("apply user %s@%s: %w", user, host, err)
+ }
+ }
+
+ // And now ensure we're applying those permissions
+ // TODO(egtann) should this go in defer after commit?
+ q = `FLUSH PRIVILEGES`
+ if _, err = tx.Exec(q); err != nil {
+ return fmt.Errorf("flush privileges: %w", err)
+ }
+ return nil
+}
+
+// applyUser permissions by directly editing priv tables so we can manually
+// trigger flushing privileges.
+func applyUser(tx *sqlx.Tx, user, host string, p *sf.Permissions) error {
+ var pq *sqlUserPerms
+ if p.Deny {
+ pq = newSQLUserPerms("N")
+ } else {
+ // This is the behavior from:
+ // GRANT ALL ON *.* TO user WITH GRANT OPTION;
+ pq = newSQLUserPerms("Y")
+ pq.File = "N"
+ pq.Super = "N"
+ }
+ q := `UPDATE mysql.user
+ SET Select_priv=?, Insert_priv=?, Update_priv=?, Delete_priv=?,
+ Create_priv=?, Drop_priv=?, Reload_priv=?, Shutdown_priv=?,
+ Process_priv=?, File_priv=?, Grant_priv=?, References_priv=?,
+ Index_priv=?, Alter_priv=?, Show_db_priv=?, Super_priv=?,
+ Create_tmp_table_priv=?, Lock_tables_priv=?, Execute_priv=?,
+ Repl_slave_priv=?, Repl_client_priv=?, Create_view_priv=?,
+ Show_view_priv=?, Create_routine_priv=?, Event_priv=?,
+ Trigger_priv=?, Create_tablespace_priv=?
+ WHERE User=? AND Host=?`
+ _, err := tx.Exec(q, pq.Select, pq.Insert, pq.Update, pq.Delete,
+ pq.Create, pq.Drop, pq.Reload, pq.Shutdown,
+ pq.Process, pq.File, pq.Grant, pq.References,
+ pq.Index, pq.Alter, pq.ShowDB, pq.Super,
+ pq.CreateTmpTable, pq.LockTables, pq.Execute,
+ pq.ReplSlave, pq.ReplClient, pq.CreateView,
+ pq.ShowView, pq.CreateRoutine, pq.Event,
+ pq.Trigger, pq.CreateTablespace,
+ user, host,
+ )
+ if err != nil {
+ return fmt.Errorf("set user privs: %w", err)
+ }
+ if !p.Deny {
+ return nil
+ }
+
+ // Update tables_priv table since we're only allowing partial access to
+ // specific tables or columns
+ tables := map[string]*tablePerm{}
+ columns := map[string]*columnPerm{}
+ for d, db := range p.Databases {
+ for t, table := range db.Tables {
+ for s, statement := range table.Statements {
+ if statement.Deny {
+ continue
+ }
+ s = strings.Title(s)
+ if s == "Show View" {
+ s = "Show view" // ugh :/
+ }
+ if allAllowed(statement.Columns) {
+ // Add a table entry, no column entry
+ var tp *tablePerm
+ key := tableKey(d, user, host, t)
+ if _, ok := tables[key]; !ok {
+ tables[key] = &tablePerm{
+ DB: d,
+ User: user,
+ Host: host,
+ TableName: t,
+ }
+ }
+ tp = tables[key]
+ tp.TablePrivs = append(tp.TablePrivs, s)
+ tables[key] = tp
+ continue
+ }
+ switch s {
+ case "Select", "Insert", "Update", "References":
+ // Keep going
+ default:
+ // Other statements cannot apply to the
+ // column-level
+ continue
+ }
+
+ // Add a table entry indicating that this is a
+ // column-level permission
+ var tp *tablePerm
+ key := tableKey(d, user, host, t)
+ if _, ok := tables[key]; !ok {
+ tables[key] = &tablePerm{
+ DB: d,
+ User: user,
+ Host: host,
+ TableName: t,
+ }
+ }
+ tp = tables[key]
+ tp.ColumnPrivs = append(tp.ColumnPrivs, s)
+ tables[key] = tp
+ for c, deny := range statement.Columns {
+ if deny {
+ continue
+ }
+ c = strings.Title(c)
+ var cp *columnPerm
+ key = columnKey(d, user, host, t, c)
+ if _, ok := columns[key]; !ok {
+ columns[key] = &columnPerm{
+ DB: d,
+ User: user,
+ Host: host,
+ TableName: t,
+ ColumnName: c,
+ }
+ }
+ cp = columns[key]
+ cp.ColumnPrivs = append(cp.ColumnPrivs, s)
+ columns[key] = cp
+ }
+ }
+ }
+ }
+
+ for _, t := range tables {
+ var (
+ tablePrivs string
+ columnPrivs string
+ )
+ if len(t.TablePrivs) > 0 {
+ tablePrivs = strings.Join(t.TablePrivs, ",")
+ }
+ if len(t.ColumnPrivs) > 0 {
+ columnPrivs = strings.Join(t.ColumnPrivs, ",")
+ }
+ if len(t.TablePrivs) == 0 && len(t.ColumnPrivs) == 0 {
+ // This shouldn't happen
+ return errors.New("empty privs")
+ }
+ q = `INSERT INTO mysql.tables_priv (
+ Db, User, Host, Table_name, Grantor, Table_priv,
+ Column_priv
+ ) VALUES (?, ?, ?, ?, (SELECT USER()), ?, ?)`
+ _, err = tx.Exec(q, t.DB, t.User, t.Host, t.TableName,
+ tablePrivs, columnPrivs)
+ if err != nil {
+ return fmt.Errorf("insert tables_priv: %w", err)
+ }
+ }
+ for _, c := range columns {
+ privs := strings.Join(c.ColumnPrivs, ",")
+ q = `INSERT INTO mysql.columns_priv (
+ Db, User, Host, Table_name, Column_name, Column_priv
+ ) VALUES (?, ?, ?, ?, ?, ?)`
+ _, err = tx.Exec(q, c.DB, c.User, c.Host, c.TableName,
+ c.ColumnName, privs)
+ if err != nil {
+ return fmt.Errorf("insert columns_priv: %w", err)
+ }
}
return nil
}
@@ 178,3 408,87 @@ func newTLSConfig(
}
return conf, nil
}
+
+type sqlUserPerms = struct {
+ Select string
+ Insert string
+ Update string
+ Delete string
+ Create string
+ Drop string
+ Reload string
+ Shutdown string
+ Process string
+ File string
+ Grant string
+ References string
+ Index string
+ Alter string
+ ShowDB string
+ Super string
+ CreateTmpTable string
+ LockTables string
+ Execute string
+ ReplSlave string
+ ReplClient string
+ CreateView string
+ ShowView string
+ CreateRoutine string
+ Event string
+ Trigger string
+ CreateTablespace string
+}
+
+// newSQLUserPerms initializes permissions all set to the provided value.
+func newSQLUserPerms(val string) *sqlUserPerms {
+ return &sqlUserPerms{
+ Select: val,
+ Insert: val,
+ Update: val,
+ Delete: val,
+ Create: val,
+ Drop: val,
+ Reload: val,
+ Shutdown: val,
+ Process: val,
+ File: val,
+ Grant: val,
+ References: val,
+ Index: val,
+ Alter: val,
+ ShowDB: val,
+ Super: val,
+ CreateTmpTable: val,
+ LockTables: val,
+ Execute: val,
+ ReplSlave: val,
+ ReplClient: val,
+ CreateView: val,
+ ShowView: val,
+ CreateRoutine: val,
+ Event: val,
+ Trigger: val,
+ CreateTablespace: val,
+ }
+}
+
+func allAllowed(columns map[string]bool) bool {
+ for _, deny := range columns {
+ if deny {
+ return false
+ }
+ }
+ return true
+}
+
+// Key is db.user.host.table_name, which is the primary key for the
+// tables_priv table.
+func tableKey(db, user, host, table string) string {
+ return fmt.Sprintf("%s.%s.%s.%s", db, user, host, table)
+}
+
+// Key is db.user.host.table_name.column_name, which is the primary key for the
+// columns_priv table.
+func columnKey(db, user, host, table, column string) string {
+ return fmt.Sprintf("%s.%s.%s.%s.%s", db, user, host, table, column)
+}
@@ 31,24 31,24 @@ type seenLines struct {
columns map[string]struct{}
}
-type permissions struct {
+type Permissions struct {
Deny bool `json:",omitempty"`
- Databases map[string]*dbPermission `json:",omitempty"`
+ Databases map[string]*DBPermission `json:",omitempty"`
}
-type dbPermission struct {
+type DBPermission struct {
Deny bool `json:",omitempty"`
- Tables map[string]*tablePermission `json:",omitempty"`
+ Tables map[string]*TablePermission `json:",omitempty"`
}
-type statementPermission struct {
+type StatementPermission struct {
Deny bool `json:",omitempty"`
Columns map[string]bool `json:",omitempty"`
}
-type tablePermission struct {
+type TablePermission struct {
Deny bool `json:",omitempty"`
- Statements map[string]*statementPermission `json:",omitempty"`
+ Statements map[string]*StatementPermission `json:",omitempty"`
}
// compile permissions for every user in statements.
@@ 56,7 56,7 @@ func compile(
schema *Schema,
allStatements []string,
a *ast,
-) (map[string]*permissions, error) {
+) (map[string]*Permissions, error) {
// First collect a list of all users
//
// TODO(egtann) this can be done during parsing...
@@ 100,8 100,9 @@ func compile(
}
// Given each user's set of lines, assemble permissions for them.
- userPerms := map[string]*permissions{}
+ userPerms := map[string]*Permissions{}
for u, lines := range userLines {
+ fmt.Println("USER LINES", u)
var err error
userPerms[u], err = permsForLines(schema, allStatements,
a.vars, lines)
@@ 115,7 116,7 @@ func compile(
// dbsAllGranted searches through all sub-permissions to determine if any are
// denied. It reports true iff all permissions are granted across every
// sub-resource.
-func dbsAllGranted(dbs map[string]*dbPermission) bool {
+func dbsAllGranted(dbs map[string]*DBPermission) bool {
for _, db := range dbs {
if db.Deny {
return false
@@ 130,7 131,7 @@ func dbsAllGranted(dbs map[string]*dbPermission) bool {
// tablesAllGranted searches through all sub-permissions to determine if any
// are denied. It reports true iff all permissions are granted across every
// sub-resource.
-func tablesAllGranted(tables map[string]*tablePermission) bool {
+func tablesAllGranted(tables map[string]*TablePermission) bool {
for _, s := range tables {
if s.Deny {
return false
@@ 151,7 152,7 @@ func tablesAllGranted(tables map[string]*tablePermission) bool {
// grants outputs the minimum set of GRANT and REVOKE statements that will
// enforce the permissions.
-func (p *permissions) grants(user string) []string {
+func (p *Permissions) grants(user string) []string {
// Default to any host, but allow overriding it
if !strings.Contains(user, "@") {
user += "@'%'"
@@ 208,7 209,7 @@ func (p *permissions) grants(user string) []string {
}
// apply permissions for a given line.
-func (p *permissions) apply(vars map[string][]string, l *line) error {
+func (p *Permissions) apply(vars map[string][]string, l *line) error {
deny := l.verb == "deny"
// Make the appropriate variable substitutions
@@ 339,7 340,7 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
return nil
}
-func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
+func markAnyDBSeen(sl *seenLines, p *Permissions) *seenLines {
for d, db := range p.Databases {
sl.databases[d] = struct{}{}
for t, table := range db.Tables {
@@ 355,7 356,7 @@ func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
return sl
}
-func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
+func markAnyTableSeen(sl *seenLines, db *DBPermission) *seenLines {
for t, table := range db.Tables {
sl.tables[t] = struct{}{}
for s, statement := range table.Statements {
@@ 368,7 369,7 @@ func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
return sl
}
-func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {
+func markAnyStatementSeen(sl *seenLines, table *TablePermission) *seenLines {
for s, statement := range table.Statements {
sl.statements[s] = struct{}{}
for c := range statement.Columns {
@@ 380,18 381,18 @@ func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {
// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
-func permsForSchema(schema *Schema, allStatements []string) *permissions {
- perms := &permissions{Databases: map[string]*dbPermission{}}
+func permsForSchema(schema *Schema, allStatements []string) *Permissions {
+ perms := &Permissions{Databases: map[string]*DBPermission{}}
for dbName, tables := range schema.Databases {
- perms.Databases[dbName] = &dbPermission{
- Tables: map[string]*tablePermission{},
+ perms.Databases[dbName] = &DBPermission{
+ Tables: map[string]*TablePermission{},
}
for tableName, columns := range tables {
- perms.Databases[dbName].Tables[tableName] = &tablePermission{
- Statements: map[string]*statementPermission{},
+ perms.Databases[dbName].Tables[tableName] = &TablePermission{
+ Statements: map[string]*StatementPermission{},
}
for _, statement := range allStatements {
- perms.Databases[dbName].Tables[tableName].Statements[statement] = &statementPermission{
+ perms.Databases[dbName].Tables[tableName].Statements[statement] = &StatementPermission{
Columns: map[string]bool{},
}
for colName := range columns {
@@ 408,9 409,10 @@ func permsForLines(
allStatements []string,
vars map[string][]string,
ls []*line,
-) (*permissions, error) {
+) (*Permissions, error) {
perms := permsForSchema(schema, allStatements)
for _, l := range ls {
+ fmt.Println("APPLY", l)
if err := perms.apply(vars, l); err != nil {
return nil, fmt.Errorf("apply line %d: %w", l.line, err)
}