@@ 15,6 15,7 @@ type ast struct {
}
type line struct {
+ line int
verb string
users []string
databases []string
@@ 23,6 24,13 @@ type line struct {
columns []string
}
+type seenLines struct {
+ databases map[string]struct{}
+ statements map[string]struct{}
+ tables map[string]struct{}
+ columns map[string]struct{}
+}
+
type permissions struct {
Deny bool `json:",omitempty"`
Databases map[string]*dbPermission `json:",omitempty"`
@@ 52,21 60,25 @@ func compile(
// First collect a list of all users
//
// TODO(egtann) this can be done during parsing...
- var allUsers []string
+ userSet := map[string]struct{}{}
for _, line := range a.lines {
for _, u := range line.users {
if strings.HasPrefix(u, "$") {
users := a.vars[strings.TrimPrefix(u, "$")]
for _, u2 := range users {
- allUsers = append(allUsers, u2)
+ userSet[u2] = struct{}{}
}
continue
}
if u != "all" {
- allUsers = append(allUsers, u)
+ userSet[u] = struct{}{}
}
}
}
+ allUsers := make([]string, 0, len(userSet))
+ for u := range userSet {
+ allUsers = append(allUsers, u)
+ }
// Create lines split out for each user, substituting variables and
// "all" users as needed.
@@ 211,12 223,7 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
// descriptive error, not silently ignoring it. If a database doesn't
// actually exist and we silently allow it, we revoke a user's
// permissions and not re-grant them.
- seen := struct {
- databases map[string]struct{}
- statements map[string]struct{}
- tables map[string]struct{}
- columns map[string]struct{}
- }{
+ seen := &seenLines{
databases: map[string]struct{}{},
statements: map[string]struct{}{},
tables: map[string]struct{}{},
@@ 232,16 239,25 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
// so many loops.
var a int
for d, db := range p.Databases {
+ if len(l.databases) == 1 && l.databases[0] == "any" {
+ seen = markAnyDBSeen(seen, p)
+ }
if len(l.databases) > 0 && !in(l.databases, d) {
continue
}
var i int
for t, table := range db.Tables {
+ if len(l.tables) == 1 && l.tables[0] == "any" {
+ seen = markAnyTableSeen(seen, db)
+ }
if len(l.tables) > 0 && !in(l.tables, t) {
continue
}
var j int
for s, statement := range table.Statements {
+ if len(l.statements) == 1 && l.statements[0] == "any" {
+ seen = markAnyStatementSeen(seen, table)
+ }
if len(l.statements) > 0 && !in(l.statements, s) {
continue
}
@@ 293,23 309,23 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
break
}
if _, ok := seen.columns[c]; !ok {
- return fmt.Errorf("unknown column: %s", c)
+ return fmt.Errorf("unapplied column rule: %s", c)
}
}
- for _, s := range l.statements {
- if s == "any" {
+ for _, t := range l.tables {
+ if t == "any" {
break
}
- if _, ok := seen.statements[s]; !ok {
- return fmt.Errorf("unknown statement: %s", s)
+ if _, ok := seen.tables[t]; !ok {
+ return fmt.Errorf("unapplied table rule: %s", t)
}
}
- for _, t := range l.tables {
- if t == "any" {
+ for _, s := range l.statements {
+ if s == "any" {
break
}
- if _, ok := seen.tables[t]; !ok {
- return fmt.Errorf("unknown table: %s", t)
+ if _, ok := seen.statements[s]; !ok {
+ return fmt.Errorf("unapplied statement rule: %s", s)
}
}
for _, d := range l.databases {
@@ 317,12 333,51 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
break
}
if _, ok := seen.databases[d]; !ok {
- return fmt.Errorf("unknown database: %s", d)
+ return fmt.Errorf("unapplied database rule: %s", d)
}
}
return nil
}
+func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
+ for d, db := range p.Databases {
+ sl.databases[d] = struct{}{}
+ for t, table := range db.Tables {
+ sl.tables[t] = struct{}{}
+ for s, statement := range table.Statements {
+ sl.statements[s] = struct{}{}
+ for c := range statement.Columns {
+ sl.columns[c] = struct{}{}
+ }
+ }
+ }
+ }
+ return sl
+}
+
+func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
+ for t, table := range db.Tables {
+ sl.tables[t] = struct{}{}
+ for s, statement := range table.Statements {
+ sl.statements[s] = struct{}{}
+ for c := range statement.Columns {
+ sl.columns[c] = struct{}{}
+ }
+ }
+ }
+ return sl
+}
+
+func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {
+ for s, statement := range table.Statements {
+ sl.statements[s] = struct{}{}
+ for c := range statement.Columns {
+ sl.columns[c] = struct{}{}
+ }
+ }
+ return sl
+}
+
// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
func permsForSchema(schema *Schema, allStatements []string) *permissions {
@@ 357,7 412,7 @@ func permsForLines(
perms := permsForSchema(schema, allStatements)
for _, l := range ls {
if err := perms.apply(vars, l); err != nil {
- return nil, fmt.Errorf("apply: %w", err)
+ return nil, fmt.Errorf("apply line %d: %w", l.line, err)
}
}
return perms, nil
@@ 367,9 422,11 @@ func permsForLines(
func parse(r io.Reader) (*ast, error) {
scn := bufio.NewScanner(r)
a := &ast{vars: map[string][]string{}}
- for scn.Scan() {
- if err := parseScanner(a, scn); err != nil {
- return nil, fmt.Errorf("parse line: %w", err)
+ for i := 1; scn.Scan(); i++ {
+ var err error
+ i, err = parseScanner(a, scn, i)
+ if err != nil {
+ return nil, fmt.Errorf("parse line %d: %w", i, err)
}
}
if err := scn.Err(); err != nil {
@@ 378,17 435,19 @@ func parse(r io.Reader) (*ast, error) {
return a, nil
}
-func parseScanner(a *ast, scn *bufio.Scanner) error {
+func parseScanner(a *ast, scn *bufio.Scanner, curLine int) (int, error) {
line := strings.TrimSpace(scn.Text())
if line == "" {
- return nil
+ return curLine, nil
}
if line[0] == '#' {
- return nil
+ return curLine, nil
}
+ retLine := curLine
for line[len(line)-1] == '\\' && scn.Scan() {
nextLine := strings.TrimSpace(scn.Text())
line = fmt.Sprintf("%s %s", line[:len(line)-1], nextLine)
+ retLine++
}
words := strings.Fields(line)
switch words[0] {
@@ 398,22 457,23 @@ func parseScanner(a *ast, scn *bufio.Scanner) error {
// may be relaxed or changed in the future.
if len(a.lines) == 0 {
if line != "deny all" {
- return errors.New("first statement must be 'deny all'")
+ return 0, errors.New("first statement must be 'deny all'")
}
}
l, err := parseLine(words)
if err != nil {
- return fmt.Errorf("parse line: %w", err)
+ return 0, fmt.Errorf("parse line %d: %w", curLine, err)
}
+ l.line = curLine
a.lines = append(a.lines, l)
default:
key, vals, err := parseVar(words)
if err != nil {
- return fmt.Errorf("parse var: %w", err)
+ return curLine, fmt.Errorf("parse var %d: %w", curLine, err)
}
a.vars[key] = vals
}
- return nil
+ return retLine, nil
}
func parseLine(words []string) (*line, error) {