package sum
import (
"bufio"
"errors"
"fmt"
"io"
"strings"
)
// ast is the parsed form of a permissions file: its variable definitions
// and its allow/deny rules.
type ast struct {
// vars maps a variable name, without its leading "$", to its values.
vars map[string][]string
// lines holds the allow/deny rules in the order they appear in the file.
lines []*line
}
// line is one parsed allow/deny rule. Each collection holds the words
// written in the rule, which may include "$name" variable references or a
// wildcard ("all" for users, "any" for the other collections). An empty
// collection means the rule omitted that clause, so it does not restrict
// that level.
type line struct {
// line is the 1-indexed source line on which the rule begins.
line int
// verb is "allow" or "deny".
verb string
users []string
databases []string
tables []string
statements []string
columns []string
}
// seenLines records which database, table, statement and column names a
// rule actually matched in the schema while it was being applied, so that
// names which matched nothing can be reported as errors.
type seenLines struct {
databases map[string]struct{}
statements map[string]struct{}
tables map[string]struct{}
columns map[string]struct{}
}
// Permissions is the compiled permission tree for a single user: databases
// contain tables, tables contain statements, and statements contain
// columns. Everything starts allowed (false). Deny holds the verb of the
// most recent rule that covered every database (true for "deny").
type Permissions struct {
Deny bool `json:",omitempty"`
Databases map[string]*DBPermission `json:",omitempty"`
}
// DBPermission holds the permissions for one database, keyed by table name.
// Deny holds the verb of the most recent rule that covered every table in
// the database (true for "deny"); it stays false until such a rule applies.
type DBPermission struct {
Deny bool `json:",omitempty"`
Tables map[string]*TablePermission `json:",omitempty"`
}
// StatementPermission holds the permissions for one statement on one table.
// Columns maps each column name to true when that column is denied. Deny
// holds the verb of the most recent rule that covered every column (true
// for "deny"); it stays false until such a rule applies.
type StatementPermission struct {
Deny bool `json:",omitempty"`
Columns map[string]bool `json:",omitempty"`
}
// TablePermission holds the permissions for one table, keyed by statement
// name; permsForSchema creates an entry for every statement in the compiled
// statement list. Deny holds the verb of the most recent rule that covered
// every statement (true for "deny"); it stays false until such a rule
// applies.
type TablePermission struct {
Deny bool `json:",omitempty"`
Statements map[string]*StatementPermission `json:",omitempty"`
}
// compile permissions for every user named in the parsed rules.
//
// Users may be named directly, through a "$name" variable, or with the
// wildcard "all", which stands for every concrete user named anywhere in
// the rules (wherever "all" appears in a rule's user list). Each user
// receives the rules that name them, in source order, and permsForLines
// applies those rules on top of a fully-allowed schema.
//
// compile returns an error if a rule names an undefined user variable or if
// any rule fails to apply for some user.
func compile(
	schema *Schema,
	allStatements []string,
	a *ast,
) (map[string]*Permissions, error) {
	// Expand each rule's users once so both passes below agree, and collect
	// the set of concrete users. "all" is a wildcard, not a real user.
	//
	// TODO(egtann) this can be done during parsing...
	ruleUsers := make([][]string, len(a.lines))
	userSet := map[string]struct{}{}
	for i, l := range a.lines {
		users, err := expandUserVars(l.users, a.vars)
		if err != nil {
			return nil, fmt.Errorf("line %d: %w", l.line, err)
		}
		ruleUsers[i] = users
		for _, u := range users {
			if u != "all" {
				userSet[u] = struct{}{}
			}
		}
	}
	allUsers := make([]string, 0, len(userSet))
	for u := range userSet {
		allUsers = append(allUsers, u)
	}

	// Hand each rule to every user it names, expanding "all" wherever it
	// appears. The rules' user lists are left untouched. A user named more
	// than once in one rule (directly, via a variable, or via "all") gets
	// that rule only once.
	userLines := map[string][]*line{}
	for i, l := range a.lines {
		assigned := map[string]struct{}{}
		for _, u := range ruleUsers[i] {
			targets := []string{u}
			if u == "all" {
				targets = allUsers
			}
			for _, target := range targets {
				if _, dup := assigned[target]; dup {
					continue
				}
				assigned[target] = struct{}{}
				userLines[target] = append(userLines[target], l)
			}
		}
	}

	// Given each user's set of rules, assemble permissions for them.
	userPerms := make(map[string]*Permissions, len(userLines))
	for u, lines := range userLines {
		perms, err := permsForLines(schema, allStatements, a.vars, lines)
		if err != nil {
			return nil, fmt.Errorf("perms for lines: %w", err)
		}
		userPerms[u] = perms
	}
	return userPerms, nil
}

// expandUserVars replaces each "$name" entry in users with the values of
// that variable. Other entries, including the wildcard "all", are kept
// as-is. An undefined variable is an error: dropping it silently would make
// the rule apply to nobody.
func expandUserVars(users []string, vars map[string][]string) ([]string, error) {
	out := make([]string, 0, len(users))
	for _, u := range users {
		if !strings.HasPrefix(u, "$") {
			out = append(out, u)
			continue
		}
		vals, ok := vars[strings.TrimPrefix(u, "$")]
		if !ok {
			return nil, fmt.Errorf("undefined user variable: %s", u)
		}
		out = append(out, vals...)
	}
	return out, nil
}
// apply a single rule to p in place.
//
// A "deny" rule writes true into every column entry and Deny flag it
// touches; an "allow" rule writes false. A collection the rule omits
// matches everything at that level, as does the wildcard "any". A parent's
// Deny flag is written only when the rule touched every child beneath it.
//
// "$name" references are expanded from vars into local copies. l is not
// modified, because compile shares the same *line among every user the
// rule applies to.
//
// apply returns an error if the rule references an undefined variable, if
// a variable expands to no values, or if a database, table, statement or
// column it names matched nothing in the schema. Failing loudly matters: a
// misspelled name that silently matched nothing would leave the rule's
// intended change unapplied.
func (p *Permissions) apply(vars map[string][]string, l *line) error {
	deny := l.verb == "deny"

	wantDBs, err := expandRuleVars("database", l.databases, vars)
	if err != nil {
		return err
	}
	wantTables, err := expandRuleVars("table", l.tables, vars)
	if err != nil {
		return err
	}
	wantStmts, err := expandRuleVars("statement", l.statements, vars)
	if err != nil {
		return err
	}
	wantCols, err := expandRuleVars("column", l.columns, vars)
	if err != nil {
		return err
	}

	// Record every schema name this rule matched. Any name in the rule
	// that is never recorded does not exist in the schema.
	seen := &seenLines{
		databases:  map[string]struct{}{},
		statements: map[string]struct{}{},
		tables:     map[string]struct{}{},
		columns:    map[string]struct{}{},
	}
	// A lone "any" covers every name at that level, so all of them count
	// as seen.
	onlyAny := func(ss []string) bool { return len(ss) == 1 && ss[0] == "any" }

	// Walk the tree and set the toggles, counting how many children of
	// each node were fully toggled so the node's own Deny flag can be set
	// when all of them were.
	//
	// TODO(egtann) opportunity for improved performance here by converting
	// the wanted slices to sets for constant-time lookups, since they're
	// used in so many loops.
	if onlyAny(wantDBs) {
		seen = markAnyDBSeen(seen, p)
	}
	var dbsToggled int
	for d, db := range p.Databases {
		if len(wantDBs) > 0 && !in(wantDBs, d) {
			continue
		}
		if onlyAny(wantTables) {
			seen = markAnyTableSeen(seen, db)
		}
		var tablesToggled int
		for t, table := range db.Tables {
			if len(wantTables) > 0 && !in(wantTables, t) {
				continue
			}
			if onlyAny(wantStmts) {
				seen = markAnyStatementSeen(seen, table)
			}
			var stmtsToggled int
			for s, stmt := range table.Statements {
				if len(wantStmts) > 0 && !in(wantStmts, s) {
					continue
				}
				var colsToggled int
				for c := range stmt.Columns {
					if len(wantCols) > 0 && !in(wantCols, c) {
						continue
					}
					stmt.Columns[c] = deny
					seen.databases[d] = struct{}{}
					seen.tables[t] = struct{}{}
					seen.statements[s] = struct{}{}
					seen.columns[c] = struct{}{}
					colsToggled++
				}
				if colsToggled == len(stmt.Columns) {
					stmt.Deny = deny
					seen.databases[d] = struct{}{}
					seen.tables[t] = struct{}{}
					seen.statements[s] = struct{}{}
					stmtsToggled++
				}
			}
			if stmtsToggled == len(table.Statements) {
				table.Deny = deny
				seen.databases[d] = struct{}{}
				seen.tables[t] = struct{}{}
				tablesToggled++
			}
		}
		if tablesToggled == len(db.Tables) {
			db.Deny = deny
			seen.databases[d] = struct{}{}
			dbsToggled++
		}
	}
	if dbsToggled == len(p.Databases) {
		p.Deny = deny
	}

	// Confirm every name in the rule matched something in the schema.
	// There is deliberately no early return when the rule covered the
	// whole tree: it can still name items that don't exist (for example,
	// any named database when the schema is empty). Work from most- to
	// least-specific to produce better errors.
	checks := []struct {
		kind  string
		names []string
		seen  map[string]struct{}
	}{
		{"column", wantCols, seen.columns},
		{"table", wantTables, seen.tables},
		{"statement", wantStmts, seen.statements},
		{"database", wantDBs, seen.databases},
	}
	for _, chk := range checks {
		for _, name := range chk.names {
			// "any" is a wildcard, not a schema name; keep validating the
			// remaining names instead of stopping here.
			if name == "any" {
				continue
			}
			if _, ok := chk.seen[name]; !ok {
				return fmt.Errorf("unapplied %s rule: %s", chk.kind, name)
			}
		}
	}
	return nil
}

// expandRuleVars expands "$name" references in one of a rule's
// collections; kind names the collection in error messages. It fails on an
// undefined variable, and when a non-empty collection expands to nothing:
// apply treats an empty collection as "match everything", so silently
// dropping the names would widen the rule instead of narrowing it.
func expandRuleVars(kind string, raw []string, vars map[string][]string) ([]string, error) {
	for _, s := range raw {
		if !strings.HasPrefix(s, "$") {
			continue
		}
		if _, ok := vars[strings.TrimPrefix(s, "$")]; !ok {
			return nil, fmt.Errorf("undefined %s variable: %s", kind, s)
		}
	}
	out := substituteVars(raw, vars)
	if len(raw) > 0 && len(out) == 0 {
		return nil, fmt.Errorf("%s collection %v expands to no values", kind, raw)
	}
	return out, nil
}
// markAnyDBSeen records every database in p, together with every table,
// statement and column beneath it, as seen. apply uses it when a rule's
// database collection is the lone wildcard "any". sl is updated in place
// and returned for convenience.
func markAnyDBSeen(sl *seenLines, p *Permissions) *seenLines {
	for dbName, dbPerm := range p.Databases {
		sl.databases[dbName] = struct{}{}
		markAnyTableSeen(sl, dbPerm)
	}
	return sl
}
// markAnyTableSeen records every table in db, together with every statement
// and column beneath it, as seen. apply uses it when a rule's table
// collection is the lone wildcard "any". sl is updated in place and
// returned for convenience.
func markAnyTableSeen(sl *seenLines, db *DBPermission) *seenLines {
	for tableName, tablePerm := range db.Tables {
		sl.tables[tableName] = struct{}{}
		markAnyStatementSeen(sl, tablePerm)
	}
	return sl
}
// markAnyStatementSeen records every statement in table, and every column
// of those statements, as seen. apply uses it when a rule's statement
// collection is the lone wildcard "any". sl is updated in place and
// returned for convenience.
func markAnyStatementSeen(sl *seenLines, table *TablePermission) *seenLines {
	for stmtName, stmtPerm := range table.Statements {
		sl.statements[stmtName] = struct{}{}
		for colName := range stmtPerm.Columns {
			sl.columns[colName] = struct{}{}
		}
	}
	return sl
}
// permsForSchema builds a permission tree covering every database, table
// and column in schema, giving each table one entry per statement in
// allStatements. Every Deny flag and column value starts false, i.e.
// everything is allowed until rules say otherwise.
func permsForSchema(schema *Schema, allStatements []string) *Permissions {
	dbPerms := make(map[string]*DBPermission, len(schema.Databases))
	for dbName, tables := range schema.Databases {
		tablePerms := make(map[string]*TablePermission, len(tables))
		for tableName, columns := range tables {
			stmtPerms := make(map[string]*StatementPermission, len(allStatements))
			for _, stmt := range allStatements {
				colPerms := make(map[string]bool, len(columns))
				for colName := range columns {
					colPerms[colName] = false
				}
				stmtPerms[stmt] = &StatementPermission{Columns: colPerms}
			}
			tablePerms[tableName] = &TablePermission{Statements: stmtPerms}
		}
		dbPerms[dbName] = &DBPermission{Tables: tablePerms}
	}
	return &Permissions{Databases: dbPerms}
}
// permsForLines starts from a fully-allowed permission tree for schema and
// applies each rule in ls in order, so later rules override earlier ones.
// The first rule that fails to apply aborts with an error naming its line.
func permsForLines(
	schema *Schema,
	allStatements []string,
	vars map[string][]string,
	ls []*line,
) (*Permissions, error) {
	tree := permsForSchema(schema, allStatements)
	for idx := range ls {
		rule := ls[idx]
		applyErr := tree.apply(vars, rule)
		if applyErr == nil {
			continue
		}
		return nil, fmt.Errorf("apply line %d: %w", rule.line, applyErr)
	}
	return tree, nil
}
// parse reads permission rules and variable definitions from r into an
// ast. Line numbers in errors are 1-indexed and refer to the line on which
// the failing statement begins, even when it spans continuation lines.
func parse(r io.Reader) (*ast, error) {
	scn := bufio.NewScanner(r)
	a := &ast{vars: map[string][]string{}}
	for lineNo := 1; scn.Scan(); lineNo++ {
		start := lineNo
		last, err := parseScanner(a, scn, start)
		if err != nil {
			// Report start rather than parseScanner's return value, which
			// is not a reliable line number when it fails.
			return nil, fmt.Errorf("parse line %d: %w", start, err)
		}
		// parseScanner may consume continuation lines; resume numbering
		// after the last one it read.
		lineNo = last
	}
	if err := scn.Err(); err != nil {
		return nil, fmt.Errorf("scan: %w", err)
	}
	return a, nil
}
// parseScanner parses the statement starting at the scanner's current line
// (numbered curLine) into a, and returns the number of the last line it
// consumed. A trailing '\' joins the next line onto the statement, so the
// returned number can be greater than curLine. Blank lines and comment
// lines (first non-space character '#') are skipped.
//
// Statements beginning with "allow" or "deny" are rules; anything else is
// parsed as a variable definition. The first rule must be exactly
// "deny all". On error the returned number is curLine, the line on which
// the failing statement begins.
func parseScanner(a *ast, scn *bufio.Scanner, curLine int) (int, error) {
	text := strings.TrimSpace(scn.Text())
	if text == "" || text[0] == '#' {
		return curLine, nil
	}
	lastLine := curLine
	for strings.HasSuffix(text, `\`) {
		text = strings.TrimSuffix(text, `\`)
		if !scn.Scan() {
			// A continuation at end of input has nothing to join; drop the
			// dangling '\' rather than parsing it as a word.
			break
		}
		lastLine++
		text += " " + strings.TrimSpace(scn.Text())
	}
	words := strings.Fields(text)
	if len(words) == 0 {
		// Nothing but continuation markers and whitespace.
		return lastLine, nil
	}
	switch words[0] {
	case "allow", "deny":
		// First rule must be "deny all" to ensure running this multiple
		// times produces the same output. This requirement may be relaxed
		// or changed in the future. Compare words, not raw text, so that
		// spacing (including spacing introduced by joined continuation
		// lines) doesn't matter.
		if len(a.lines) == 0 &&
			(len(words) != 2 || words[0] != "deny" || words[1] != "all") {
			return curLine, errors.New("first statement must be 'deny all'")
		}
		l, err := parseLine(words)
		if err != nil {
			return curLine, fmt.Errorf("parse rule: %w", err)
		}
		l.line = curLine
		a.lines = append(a.lines, l)
	default:
		key, vals, err := parseVar(words)
		if err != nil {
			return curLine, fmt.Errorf("parse var: %w", err)
		}
		a.vars[key] = vals
	}
	return lastLine, nil
}
// parseLine converts the words of one allow/deny rule into a line. The
// expected shape is
//
//	<allow|deny> <users> [db <dbs> [table <tables> [statement <stmts> [column <cols>]]]]
//
// where each collection is a single word or a "{ ... }" group. The returned
// line has no source line number set.
func parseLine(words []string) (*line, error) {
	if len(words) < 2 {
		return nil, errors.New("line is too short")
	}
	verb := words[0]
	if verb != "allow" && verb != "deny" {
		return nil, fmt.Errorf("unexpected %s, want allow|deny", verb)
	}
	l := &line{verb: verb}
	rest := words[1:]

	var err error
	if l.users, rest, err = parseCollection("", rest); err != nil {
		return nil, fmt.Errorf("bad user collection: %w", err)
	}
	if len(l.users) == 0 {
		return nil, errors.New("expected user, got none")
	}

	// The optional clauses must appear in this order; parseCollection
	// rejects a clause whose leading keyword is not the expected one.
	clauses := []struct {
		leader string
		dst    *[]string
	}{
		{"db", &l.databases},
		{"table", &l.tables},
		{"statement", &l.statements},
		{"column", &l.columns},
	}
	for _, c := range clauses {
		if *c.dst, rest, err = parseCollection(c.leader, rest); err != nil {
			return nil, fmt.Errorf("bad %s collection: %w", c.leader, err)
		}
	}
	if len(rest) > 0 {
		return nil, fmt.Errorf("expected <eol>, got %s", rest)
	}

	// Reject a more specific clause whose less specific parent is missing:
	// accepting it silently would widen the rule, which is hard to notice.
	switch {
	case len(l.tables) > 0 && len(l.databases) == 0:
		return nil, errors.New("undefined database")
	case len(l.statements) > 0 && len(l.tables) == 0:
		return nil, errors.New("undefined table")
	case len(l.columns) > 0 && len(l.statements) == 0:
		return nil, errors.New("undefined statement")
	}
	return l, nil
}
// parseCollection parses one collection from the front of words and
// returns the collection, the words remaining after it (nil when none
// remain), and an error if any.
//
// If leader is non-empty, words must begin with that keyword (e.g. "db"),
// followed by the collection itself. A collection is either a single word
// or a "{ ... }" group of words; the braces must be separate words. An
// empty group is rejected, because callers treat an empty collection as
// "no restriction" and "db { }" would otherwise match every database.
//
// When words is empty, nil slices and a nil error are returned, meaning the
// collection was omitted.
func parseCollection(leader string, words []string) ([]string, []string, error) {
	if len(words) == 0 {
		return nil, nil, nil
	}
	i := 0
	if leader != "" {
		if words[0] != leader {
			return nil, nil, fmt.Errorf("expected %s, got %s", leader, words[0])
		}
		if len(words) < 2 {
			return nil, nil, fmt.Errorf("expected %s name, got <eol>", leader)
		}
		i = 1
	}
	// remainder returns the words after index j, or nil when none remain.
	remainder := func(j int) []string {
		if len(words) > j+1 {
			return words[j+1:]
		}
		return nil
	}
	if words[i] != "{" {
		return []string{words[i]}, remainder(i), nil
	}
	var col []string
	for j := i + 1; j < len(words); j++ {
		if words[j] == "}" {
			if len(col) == 0 {
				return nil, nil, errors.New("empty collection")
			}
			return col, remainder(j), nil
		}
		col = append(col, words[j])
	}
	return nil, nil, errors.New("unmatched bracket")
}
// parseVar parses a variable definition of the form
//
//	name = value
//	name = { value1 value2 ... }
//
// and returns the name and its values. Anything after the value collection
// is an error, so "name = a b" is rejected rather than silently truncated
// to [a].
func parseVar(words []string) (string, []string, error) {
	if len(words) < 3 {
		return "", nil, errors.New("line too short")
	}
	// TODO(egtann) perhaps some regex validation on this to ensure it's
	// [a-z][a-z0-9_]*
	key := words[0]
	if words[1] != "=" {
		return "", nil, fmt.Errorf("expected =, got %s", words[1])
	}
	vals, rest, err := parseCollection("", words[2:])
	if err != nil {
		return "", nil, fmt.Errorf("parse collection: %w", err)
	}
	if len(rest) > 0 {
		return "", nil, fmt.Errorf("expected <eol>, got %s", rest)
	}
	return key, vals, nil
}
// substituteVars returns ss with every "$name" entry replaced by the values
// of vars["name"]. Entries without a leading "$" are copied unchanged, and
// a reference to an undefined variable contributes nothing. The result is
// nil when nothing is produced.
func substituteVars(ss []string, vars map[string][]string) []string {
	var expanded []string
	for i := range ss {
		word := ss[i]
		if name := strings.TrimPrefix(word, "$"); name != word {
			expanded = append(expanded, vars[name]...)
		} else {
			expanded = append(expanded, word)
		}
	}
	return expanded
}
// in reports whether s appears in ss, treating an "any" entry in ss as a
// wildcard that matches every s.
func in(ss []string, s string) bool {
	for i := range ss {
		if ss[i] == s || ss[i] == "any" {
			return true
		}
	}
	return false
}