~egtann/sum

ee1a1b48d30802877aa6ed635cbd660ee327d4d6 — Evan Tann 4 months ago ce3ca90
add line nums to errors, fix dupe users and seen bug
1 files changed, 90 insertions(+), 30 deletions(-)

M parser.go
M parser.go => parser.go +90 -30
@@ 15,6 15,7 @@ type ast struct {
}

type line struct {
	line       int
	verb       string
	users      []string
	databases  []string


@@ 23,6 24,13 @@ type line struct {
	columns    []string
}

type seenLines struct {
	databases  map[string]struct{}
	statements map[string]struct{}
	tables     map[string]struct{}
	columns    map[string]struct{}
}

type permissions struct {
	Deny      bool                     `json:",omitempty"`
	Databases map[string]*dbPermission `json:",omitempty"`


@@ 52,21 60,25 @@ func compile(
	// First collect a list of all users
	//
	// TODO(egtann) this can be done during parsing...
	var allUsers []string
	userSet := map[string]struct{}{}
	for _, line := range a.lines {
		for _, u := range line.users {
			if strings.HasPrefix(u, "$") {
				users := a.vars[strings.TrimPrefix(u, "$")]
				for _, u2 := range users {
					allUsers = append(allUsers, u2)
					userSet[u2] = struct{}{}
				}
				continue
			}
			if u != "all" {
				allUsers = append(allUsers, u)
				userSet[u] = struct{}{}
			}
		}
	}
	allUsers := make([]string, 0, len(userSet))
	for u := range userSet {
		allUsers = append(allUsers, u)
	}

	// Create lines split out for each user, substituting variables and
	// "all" users as needed.


@@ 211,12 223,7 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
	// descriptive error, not silently ignoring it. If a database doesn't
	// actually exist and we silently allow it, we revoke a user's
	// permissions and not re-grant them.
	seen := struct {
		databases  map[string]struct{}
		statements map[string]struct{}
		tables     map[string]struct{}
		columns    map[string]struct{}
	}{
	seen := &seenLines{
		databases:  map[string]struct{}{},
		statements: map[string]struct{}{},
		tables:     map[string]struct{}{},


@@ 232,16 239,25 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
	// so many loops.
	var a int
	for d, db := range p.Databases {
		if len(l.databases) == 1 && l.databases[0] == "any" {
			seen = markAnyDBSeen(seen, p)
		}
		if len(l.databases) > 0 && !in(l.databases, d) {
			continue
		}
		var i int
		for t, table := range db.Tables {
			if len(l.tables) == 1 && l.tables[0] == "any" {
				seen = markAnyTableSeen(seen, db)
			}
			if len(l.tables) > 0 && !in(l.tables, t) {
				continue
			}
			var j int
			for s, statement := range table.Statements {
				if len(l.statements) == 1 && l.statements[0] == "any" {
					seen = markAnyStatementSeen(seen, table)
				}
				if len(l.statements) > 0 && !in(l.statements, s) {
					continue
				}


@@ 293,23 309,23 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
			break
		}
		if _, ok := seen.columns[c]; !ok {
			return fmt.Errorf("unknown column: %s", c)
			return fmt.Errorf("unapplied column rule: %s", c)
		}
	}
	for _, s := range l.statements {
		if s == "any" {
	for _, t := range l.tables {
		if t == "any" {
			break
		}
		if _, ok := seen.statements[s]; !ok {
			return fmt.Errorf("unknown statement: %s", s)
		if _, ok := seen.tables[t]; !ok {
			return fmt.Errorf("unapplied table rule: %s", t)
		}
	}
	for _, t := range l.tables {
		if t == "any" {
	for _, s := range l.statements {
		if s == "any" {
			break
		}
		if _, ok := seen.tables[t]; !ok {
			return fmt.Errorf("unknown table: %s", t)
		if _, ok := seen.statements[s]; !ok {
			return fmt.Errorf("unapplied statement rule: %s", s)
		}
	}
	for _, d := range l.databases {


@@ 317,12 333,51 @@ func (p *permissions) apply(vars map[string][]string, l *line) error {
			break
		}
		if _, ok := seen.databases[d]; !ok {
			return fmt.Errorf("unknown database: %s", d)
			return fmt.Errorf("unapplied database rule: %s", d)
		}
	}
	return nil
}

func markAnyDBSeen(sl *seenLines, p *permissions) *seenLines {
	for d, db := range p.Databases {
		sl.databases[d] = struct{}{}
		for t, table := range db.Tables {
			sl.tables[t] = struct{}{}
			for s, statement := range table.Statements {
				sl.statements[s] = struct{}{}
				for c := range statement.Columns {
					sl.columns[c] = struct{}{}
				}
			}
		}
	}
	return sl
}

func markAnyTableSeen(sl *seenLines, db *dbPermission) *seenLines {
	for t, table := range db.Tables {
		sl.tables[t] = struct{}{}
		for s, statement := range table.Statements {
			sl.statements[s] = struct{}{}
			for c := range statement.Columns {
				sl.columns[c] = struct{}{}
			}
		}
	}
	return sl
}

func markAnyStatementSeen(sl *seenLines, table *tablePermission) *seenLines {
	for s, statement := range table.Statements {
		sl.statements[s] = struct{}{}
		for c := range statement.Columns {
			sl.columns[c] = struct{}{}
		}
	}
	return sl
}

// permsForSchema builds a permission-set for a schema with every database,
// statement, table and column filled out. Everything defaults to allowed.
func permsForSchema(schema *Schema, allStatements []string) *permissions {


@@ 357,7 412,7 @@ func permsForLines(
	perms := permsForSchema(schema, allStatements)
	for _, l := range ls {
		if err := perms.apply(vars, l); err != nil {
			return nil, fmt.Errorf("apply: %w", err)
			return nil, fmt.Errorf("apply line %d: %w", l.line, err)
		}
	}
	return perms, nil


@@ 367,9 422,11 @@ func permsForLines(
func parse(r io.Reader) (*ast, error) {
	scn := bufio.NewScanner(r)
	a := &ast{vars: map[string][]string{}}
	for scn.Scan() {
		if err := parseScanner(a, scn); err != nil {
			return nil, fmt.Errorf("parse line: %w", err)
	for i := 1; scn.Scan(); i++ {
		var err error
		i, err = parseScanner(a, scn, i)
		if err != nil {
			return nil, fmt.Errorf("parse line %d: %w", i, err)
		}
	}
	if err := scn.Err(); err != nil {


@@ 378,17 435,19 @@ func parse(r io.Reader) (*ast, error) {
	return a, nil
}

func parseScanner(a *ast, scn *bufio.Scanner) error {
func parseScanner(a *ast, scn *bufio.Scanner, curLine int) (int, error) {
	line := strings.TrimSpace(scn.Text())
	if line == "" {
		return nil
		return curLine, nil
	}
	if line[0] == '#' {
		return nil
		return curLine, nil
	}
	retLine := curLine
	for line[len(line)-1] == '\\' && scn.Scan() {
		nextLine := strings.TrimSpace(scn.Text())
		line = fmt.Sprintf("%s %s", line[:len(line)-1], nextLine)
		retLine++
	}
	words := strings.Fields(line)
	switch words[0] {


@@ 398,22 457,23 @@ func parseScanner(a *ast, scn *bufio.Scanner) error {
		// may be relaxed or changed in the future.
		if len(a.lines) == 0 {
			if line != "deny all" {
				return errors.New("first statement must be 'deny all'")
				return 0, errors.New("first statement must be 'deny all'")
			}
		}
		l, err := parseLine(words)
		if err != nil {
			return fmt.Errorf("parse line: %w", err)
			return 0, fmt.Errorf("parse line %d: %w", curLine, err)
		}
		l.line = curLine
		a.lines = append(a.lines, l)
	default:
		key, vals, err := parseVar(words)
		if err != nil {
			return fmt.Errorf("parse var: %w", err)
			return curLine, fmt.Errorf("parse var %d: %w", curLine, err)
		}
		a.vars[key] = vals
	}
	return nil
	return retLine, nil
}

func parseLine(words []string) (*line, error) {