~emersion/soju

1ac895430afb455537ef667987c4803c4c6348d6 — Simon Ser 1 year, 4 months ago 7ebe47a
contrib/znc-import: new utility

Allows populating the soju database from a ZNC config file.
1 files changed, 469 insertions(+), 0 deletions(-)

A contrib/znc-import.go
A contrib/znc-import.go => contrib/znc-import.go +469 -0
@@ 0,0 1,469 @@
package main

import (
	"bufio"
	"flag"
	"fmt"
	"io"
	"log"
	"net/url"
	"os"
	"strings"
	"unicode"

	"git.sr.ht/~emersion/soju"
	"git.sr.ht/~emersion/soju/config"
)

const usage = `usage: znc-import [options...] <znc config path>

Imports configuration from a ZNC file. Users and networks are merged if they
already exist in the soju database. ZNC settings overwrite existing soju
settings.

Options:

  -help             Show this help message
  -config <path>    Path to soju config file
  -user <username>  Limit import to username (may be specified multiple times)
  -network <name>   Limit import to network (may be specified multiple times)
`

func init() {
	flag.Usage = func() {
		fmt.Fprintf(flag.CommandLine.Output(), usage)
	}
}

func main() {
	var configPath string
	users := make(map[string]bool)
	networks := make(map[string]bool)
	flag.StringVar(&configPath, "config", "", "path to configuration file")
	flag.Var((*stringSetFlag)(&users), "user", "")
	flag.Var((*stringSetFlag)(&networks), "network", "")
	flag.Parse()

	zncPath := flag.Arg(0)
	if zncPath == "" {
		flag.Usage()
		os.Exit(1)
	}

	var cfg *config.Server
	if configPath != "" {
		var err error
		cfg, err = config.Load(configPath)
		if err != nil {
			log.Fatalf("failed to load config file: %v", err)
		}
	} else {
		cfg = config.Defaults()
	}

	db, err := soju.OpenSQLDB(cfg.SQLDriver, cfg.SQLSource)
	if err != nil {
		log.Fatalf("failed to open database: %v", err)
	}
	defer db.Close()

	f, err := os.Open(zncPath)
	if err != nil {
		log.Fatalf("failed to open ZNC configuration file: %v", err)
	}
	defer f.Close()

	zp := zncParser{bufio.NewReader(f), 1}
	root, err := zp.sectionBody("", "")
	if err != nil {
		log.Fatalf("failed to parse %q: line %v: %v", zncPath, zp.line, err)
	}

	l, err := db.ListUsers()
	if err != nil {
		log.Fatalf("failed to list users in DB: %v", err)
	}
	existingUsers := make(map[string]*soju.User, len(l))
	for i, u := range l {
		existingUsers[u.Username] = &l[i]
	}

	usersCreated := 0
	usersImported := 0
	networksImported := 0
	channelsImported := 0
	root.ForEach("User", func(section *zncSection) {
		username := section.Name
		if len(users) > 0 && !users[username] {
			return
		}
		usersImported++

		u, ok := existingUsers[username]
		if ok {
			log.Printf("user %q: updating existing user", username)
		} else {
			// "!!" is an invalid crypt format, thus disables password auth
			u = &soju.User{Username: username, Password: "!!"}
			usersCreated++
			log.Printf("user %q: creating new user", username)
		}

		u.Admin = section.Values.Get("Admin") == "true"

		if err := db.StoreUser(u); err != nil {
			log.Fatalf("failed to store user %q: %v", username, err)
		}

		l, err := db.ListNetworks(username)
		if err != nil {
			log.Fatalf("failed to list networks for user %q: %v", username, err)
		}
		existingNetworks := make(map[string]*soju.Network, len(l))
		for i, n := range l {
			existingNetworks[n.GetName()] = &l[i]
		}

		nick := section.Values.Get("Nick")
		realname := section.Values.Get("RealName")
		ident := section.Values.Get("Ident")

		section.ForEach("Network", func(section *zncSection) {
			netName := section.Name
			if len(networks) > 0 && !networks[netName] {
				return
			}
			networksImported++

			logPrefix := fmt.Sprintf("user %q: network %q: ", username, netName)
			logger := log.New(os.Stderr, logPrefix, log.LstdFlags|log.Lmsgprefix)

			netNick := section.Values.Get("Nick")
			if netNick == "" {
				netNick = nick
			}
			netRealname := section.Values.Get("RealName")
			if netRealname == "" {
				netRealname = realname
			}
			netIdent := section.Values.Get("Ident")
			if netIdent == "" {
				netIdent = ident
			}

			for _, name := range section.Values["LoadModule"] {
				switch name {
				case "sasl":
					logger.Printf("warning: SASL credentials not imported")
				case "nickserv":
					logger.Printf("warning: NickServ credentials not imported")
				case "perform":
					logger.Printf("warning: \"perform\" plugin commands not imported")
				}
			}

			u, pass, err := importNetworkServer(section.Values.Get("Server"))
			if err != nil {
				logger.Fatalf("failed to import server %q: %v", section.Values.Get("Server"), err)
			}

			n, ok := existingNetworks[netName]
			if ok {
				logger.Printf("updating existing network")
			} else {
				n = &soju.Network{Name: netName}
				logger.Printf("creating new network")
			}

			n.Addr = u.String()
			n.Nick = netNick
			n.Username = netIdent
			n.Realname = netRealname
			n.Pass = pass

			if err := db.StoreNetwork(username, n); err != nil {
				logger.Fatalf("failed to store network: %v", err)
			}

			l, err := db.ListChannels(n.ID)
			if err != nil {
				logger.Fatalf("failed to list channels: %v", err)
			}
			existingChannels := make(map[string]*soju.Channel, len(l))
			for i, ch := range l {
				existingChannels[ch.Name] = &l[i]
			}

			section.ForEach("Chan", func(section *zncSection) {
				chName := section.Name

				if section.Values.Get("Disabled") == "true" {
					logger.Printf("skipping import of disabled channel %q", chName)
					return
				}

				channelsImported++

				ch, ok := existingChannels[chName]
				if ok {
					logger.Printf("channel %q: updating existing channel", chName)
				} else {
					ch = &soju.Channel{Name: chName}
					logger.Printf("channel %q: creating new channel", chName)
				}

				ch.Key = section.Values.Get("Key")
				ch.Detached = section.Values.Get("Detached") == "true"

				if err := db.StoreChannel(n.ID, ch); err != nil {
					logger.Printf("channel %q: failed to store channel: %v", chName, err)
				}
			})
		})
	})

	if err := db.Close(); err != nil {
		log.Printf("failed to close database: %v", err)
	}

	if usersCreated > 0 {
		log.Printf("warning: user passwords haven't been imported, please set them with `sojuctl change-password <username>`")
	}

	log.Printf("imported %v users, %v networks and %v channels", usersImported, networksImported, channelsImported)
}

func importNetworkServer(s string) (u *url.URL, pass string, err error) {
	parts := strings.Fields(s)
	if len(parts) < 2 {
		return nil, "", fmt.Errorf("expected space-separated host and port")
	}

	scheme := "irc+insecure"
	host := parts[0]
	port := parts[1]
	if strings.HasPrefix(port, "+") {
		port = port[1:]
		scheme = "ircs"
	}

	if len(parts) > 2 {
		pass = parts[2]
	}

	u = &url.URL{
		Scheme: scheme,
		Host:   host + ":" + port,
	}
	return u, pass, nil
}

type zncSection struct {
	Type     string
	Name     string
	Values   zncValues
	Children []zncSection
}

func (s *zncSection) ForEach(typ string, f func(*zncSection)) {
	for _, section := range s.Children {
		if section.Type == typ {
			f(&section)
		}
	}
}

type zncValues map[string][]string

func (zv zncValues) Get(k string) string {
	if len(zv[k]) == 0 {
		return ""
	}
	return zv[k][0]
}

type zncParser struct {
	br   *bufio.Reader
	line int
}

func (zp *zncParser) readByte() (byte, error) {
	b, err := zp.br.ReadByte()
	if b == '\n' {
		zp.line++
	}
	return b, err
}

func (zp *zncParser) readRune() (rune, int, error) {
	r, n, err := zp.br.ReadRune()
	if r == '\n' {
		zp.line++
	}
	return r, n, err
}

func (zp *zncParser) sectionBody(typ, name string) (*zncSection, error) {
	section := &zncSection{Type: typ, Name: name, Values: make(zncValues)}

Loop:
	for {
		if err := zp.skipSpace(); err != nil {
			return nil, err
		}

		b, err := zp.br.Peek(2)
		if err == io.EOF {
			break
		} else if err != nil {
			return nil, err
		}

		switch b[0] {
		case '<':
			if b[1] == '/' {
				break Loop
			} else {
				childType, childName, err := zp.sectionHeader()
				if err != nil {
					return nil, err
				}
				child, err := zp.sectionBody(childType, childName)
				if err != nil {
					return nil, err
				}
				if footerType, err := zp.sectionFooter(); err != nil {
					return nil, err
				} else if footerType != childType {
					return nil, fmt.Errorf("invalid section footer: expected type %q, got %q", childType, footerType)
				}
				section.Children = append(section.Children, *child)
			}
		case '/':
			if b[1] == '/' {
				if err := zp.skipComment(); err != nil {
					return nil, err
				}
				break
			}
			fallthrough
		default:
			k, v, err := zp.keyValuePair()
			if err != nil {
				return nil, err
			}
			section.Values[k] = append(section.Values[k], v)
		}
	}

	return section, nil
}

func (zp *zncParser) skipSpace() error {
	for {
		r, _, err := zp.readRune()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}

		if !unicode.IsSpace(r) {
			zp.br.UnreadRune()
			return nil
		}
	}
}

func (zp *zncParser) skipComment() error {
	if err := zp.expectRune('/'); err != nil {
		return err
	}
	if err := zp.expectRune('/'); err != nil {
		return err
	}

	for {
		b, err := zp.readByte()
		if err == io.EOF {
			return nil
		} else if err != nil {
			return err
		}

		if b == '\n' {
			return nil
		}
	}
}

func (zp *zncParser) sectionHeader() (string, string, error) {
	if err := zp.expectRune('<'); err != nil {
		return "", "", err
	}
	typ, err := zp.readWord(' ')
	if err != nil {
		return "", "", err
	}
	name, err := zp.readWord('>')
	return typ, name, err
}

func (zp *zncParser) sectionFooter() (string, error) {
	if err := zp.expectRune('<'); err != nil {
		return "", err
	}
	if err := zp.expectRune('/'); err != nil {
		return "", err
	}
	return zp.readWord('>')
}

func (zp *zncParser) keyValuePair() (string, string, error) {
	k, err := zp.readWord('=')
	if err != nil {
		return "", "", err
	}
	v, err := zp.readWord('\n')
	return strings.TrimSpace(k), strings.TrimSpace(v), err
}

func (zp *zncParser) expectRune(expected rune) error {
	r, _, err := zp.readRune()
	if err != nil {
		return err
	} else if r != expected {
		return fmt.Errorf("expected %q, got %q", expected, r)
	}
	return nil
}

func (zp *zncParser) readWord(delim byte) (string, error) {
	var sb strings.Builder
	for {
		b, err := zp.readByte()
		if err != nil {
			return "", err
		}

		if b == delim {
			return sb.String(), nil
		}
		if b == '\n' {
			return "", fmt.Errorf("expected %q before newline", delim)
		}

		sb.WriteByte(b)
	}
}

type stringSetFlag map[string]bool

func (v *stringSetFlag) String() string {
	return fmt.Sprint(map[string]bool(*v))
}

func (v *stringSetFlag) Set(s string) error {
	(*v)[s] = true
	return nil
}