~egtann/sum

7949e028d8ef440f3e49eef123b5ac2c3733e946 — Evan Tann 7 months ago 03a1fa3
add all keyword support, allow custom hosts
3 files changed, 42 insertions(+), 15 deletions(-)

M README.md
M parser.go
M testdata/complex
M README.md => README.md +2 -3
@@ 49,7 49,7 @@ And a `sf.conf` input similar to the following:

```
read  = { bob jim }
write = { alice _dashboard }
write = { alice@localhost _dashboard }

# default deny
deny all


@@ 75,7 75,7 @@ REVOKE ALL PRIVILEGES FROM '%'@'%';

GRANT ALL PRIVILEGES TO root@'%' WITH GRANT OPTION;

GRANT SELECT, INSERT, UPDATE, DELETE TO alice@'%';
GRANT SELECT, INSERT, UPDATE, DELETE TO alice@localhost;
GRANT SELECT, INSERT, UPDATE, DELETE TO _dashboard@'%';

GRANT SELECT (id, email) ON dashboard.admins   TO bob@'%';


@@ 111,6 111,5 @@ every database migration.

## TODO

* support `all` for db, statement, table, column keywords
* quoted values, e.g. multi-word statement "alter routine"
* combine users into the `permissions` struct to simplify code?

M parser.go => parser.go +37 -9
@@ 143,9 143,14 @@ func tablesAllGranted(tables map[string]*tablePermission) bool {
// grants outputs the minimum set of GRANT and REVOKE statements that will
// enforce the permissions.
func (p *permissions) grants(user string) []string {
	// Default to any host, but allow overriding it
	if !strings.Contains(user, "@") {
		user += "@'%'"
	}

	// Combine all grants into a cross-db grant where appropriate.
	if dbsAllGranted(p.Databases) {
		return []string{fmt.Sprintf("GRANT ALL PRIVILEGES ON *.* TO %s@'%%' WITH GRANT OPTION", user)}
		return []string{fmt.Sprintf("GRANT ALL PRIVILEGES ON *.* TO %s WITH GRANT OPTION", user)}
	}

	// If we're here, that mean we're denying at least one thing for this


@@ 155,7 160,7 @@ func (p *permissions) grants(user string) []string {
		// If every single privilege was granted across the database,
		// then combine that into a large GRANT.
		if tablesAllGranted(db.Tables) {
			out = append(out, fmt.Sprintf("GRANT ALL PRIVILEGES ON %s TO %s@'%%'", d, user))
			out = append(out, fmt.Sprintf("GRANT ALL PRIVILEGES ON %s TO %s", d, user))
			continue
		}



@@ 186,7 191,7 @@ func (p *permissions) grants(user string) []string {
			if len(parts) > 0 {
				sort.Strings(parts)
				tmp := strings.Join(parts, ", ")
				out = append(out, fmt.Sprintf("GRANT %s ON %s.%s TO %s@'%%'", tmp, d, t, user))
				out = append(out, fmt.Sprintf("GRANT %s ON %s.%s TO %s", tmp, d, t, user))
			}
		}
	}


@@ 367,18 372,21 @@ func parseLine(words []string) (*line, error) {
	if err != nil {
		return nil, fmt.Errorf("bad user collection: %w", err)
	}
	if len(l.users) == 0 {
		return nil, errors.New("expected user, got none")
	}
	l.databases, words, err = parseCollection("db", words)
	if err != nil {
		return nil, fmt.Errorf("bad db collection: %w", err)
	}
	l.statements, words, err = parseCollection("statement", words)
	if err != nil {
		return nil, fmt.Errorf("bad statement collection: %w", err)
	}
	l.tables, words, err = parseCollection("table", words)
	if err != nil {
		return nil, fmt.Errorf("bad table collection: %w", err)
	}
	l.statements, words, err = parseCollection("statement", words)
	if err != nil {
		return nil, fmt.Errorf("bad statement collection: %w", err)
	}
	l.columns, words, err = parseCollection("column", words)
	if err != nil {
		return nil, fmt.Errorf("bad column collection: %w", err)


@@ 386,6 394,25 @@ func parseLine(words []string) (*line, error) {
	if len(words) > 0 {
		return nil, fmt.Errorf("expected <eol>, got %s", words)
	}

	// Don't allow more specific definitions like "column" without all
	// less-specific definitions like "table" or "statement" being defined.
	// If this fails silently, it's difficult to detect.
	if len(l.tables) > 0 {
		if len(l.databases) == 0 {
			return nil, errors.New("undefined database")
		}
	}
	if len(l.statements) > 0 {
		if len(l.tables) == 0 {
			return nil, errors.New("undefined table")
		}
	}
	if len(l.columns) > 0 {
		if len(l.statements) == 0 {
			return nil, errors.New("undefined statement")
		}
	}
	return l, nil
}



@@ 440,7 467,7 @@ func parseVar(words []string) (string, []string, error) {
	}

	// TODO(egtann) perhaps some regex validation on this to ensure it's
	// [a-z_][a-z0-9_]*
	// [a-z][a-z0-9_]*
	key := words[0]

	if words[1] != "=" {


@@ 471,7 498,8 @@ func substituteVars(ss []string, vars map[string][]string) []string {

func in(ss []string, s string) bool {
	for _, x := range ss {
		if x == s {
		switch x {
		case "all", s:
			return true
		}
	}

M testdata/complex => testdata/complex +3 -3
@@ 1,5 1,5 @@
read  = { bob jim }
write = { alice _dashboard }
write = { alice@localhost _dashboard }
admin = sarah
crud  = { select insert update delete }



@@ 7,6 7,6 @@ deny all

allow root

allow $read  db dashboard statement select table { admins users }
allow $write db dashboard statement $crud
allow $read  db dashboard table { admins users } statement select
allow $write db dashboard table all statement $crud
allow $admin db { dashboard nogrant }