~emersion/soju

203dc3df6ada6d9567382d3a40b39e3927188033 — fox.cpp 4 years ago c43ce0d
Implement upstream SASL EXTERNAL support

Closes: https://todo.sr.ht/~emersion/soju/47
5 files changed, 251 insertions(+), 8 deletions(-)

M db.go
M doc/soju.1.scd
M downstream.go
M service.go
M upstream.go
M db.go => db.go +29 -7
@@ 21,6 21,14 @@ type SASL struct {
		Username string
		Password string
	}

	// TLS client certificate authentication.
	External struct {
		// X.509 certificate in DER form.
		CertBlob []byte
		// PKCS#8 private key in DER form.
		PrivKeyBlob []byte
	}
}

type Network struct {


@@ 68,6 76,8 @@ CREATE TABLE Network (
	sasl_mechanism VARCHAR(255),
	sasl_plain_username VARCHAR(255),
	sasl_plain_password VARCHAR(255),
	sasl_external_cert BLOB DEFAULT NULL,
	sasl_external_key BLOB DEFAULT NULL,
	FOREIGN KEY(user) REFERENCES User(username),
	UNIQUE(user, addr, nick)
);


@@ 87,6 97,8 @@ var migrations = []string{
	"", // migration #0 is reserved for schema initialization
	"ALTER TABLE Network ADD COLUMN connect_commands VARCHAR(1023)",
	"ALTER TABLE Channel ADD COLUMN detached INTEGER NOT NULL DEFAULT 0",
	"ALTER TABLE Network ADD COLUMN sasl_external_cert BLOB DEFAULT NULL",
	"ALTER TABLE Network ADD COLUMN sasl_external_key BLOB DEFAULT NULL",
}

type DB struct {


@@ 238,7 250,8 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
	defer db.lock.RUnlock()

	rows, err := db.db.Query(`SELECT id, name, addr, nick, username, realname, pass,
			connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password
			connect_commands, sasl_mechanism, sasl_plain_username, sasl_plain_password,
			sasl_external_cert, sasl_external_key
		FROM Network
		WHERE user = ?`,
		username)


@@ 253,7 266,8 @@ func (db *DB) ListNetworks(username string) ([]Network, error) {
		var name, username, realname, pass, connectCommands *string
		var saslMechanism, saslPlainUsername, saslPlainPassword *string
		err := rows.Scan(&net.ID, &name, &net.Addr, &net.Nick, &username, &realname,
			&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword)
			&pass, &connectCommands, &saslMechanism, &saslPlainUsername, &saslPlainPassword,
			&net.SASL.External.CertBlob, &net.SASL.External.PrivKeyBlob)
		if err != nil {
			return nil, err
		}


@@ 293,6 307,10 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
		case "PLAIN":
			saslPlainUsername = toStringPtr(network.SASL.Plain.Username)
			saslPlainPassword = toStringPtr(network.SASL.Plain.Password)
			network.SASL.External.CertBlob = nil
			network.SASL.External.PrivKeyBlob = nil
		case "EXTERNAL":
			// keep saslPlain* nil
		default:
			return fmt.Errorf("soju: cannot store network: unsupported SASL mechanism %q", network.SASL.Mechanism)
		}


@@ 302,18 320,22 @@ func (db *DB) StoreNetwork(username string, network *Network) error {
	if network.ID != 0 {
		_, err = db.db.Exec(`UPDATE Network
			SET name = ?, addr = ?, nick = ?, username = ?, realname = ?, pass = ?, connect_commands = ?,
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?
				sasl_mechanism = ?, sasl_plain_username = ?, sasl_plain_password = ?,
				sasl_external_cert = ?, sasl_external_key = ?
			WHERE id = ?`,
			netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword, network.ID)
			saslMechanism, saslPlainUsername, saslPlainPassword,
			network.SASL.External.CertBlob, network.SASL.External.PrivKeyBlob,
			network.ID)
	} else {
		var res sql.Result
		res, err = db.db.Exec(`INSERT INTO Network(user, name, addr, nick, username,
				realname, pass, connect_commands, sasl_mechanism, sasl_plain_username,
				sasl_plain_password)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
				sasl_plain_password, sasl_external_cert, sasl_external_key)
			VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
			username, netName, network.Addr, network.Nick, netUsername, realname, pass, connectCommands,
			saslMechanism, saslPlainUsername, saslPlainPassword)
			saslMechanism, saslPlainUsername, saslPlainPassword, network.SASL.External.CertBlob,
			network.SASL.External.PrivKeyBlob)
		if err != nil {
			return err
		}

M doc/soju.1.scd => doc/soju.1.scd +21 -0
@@ 119,6 119,27 @@ abbreviated form, for instance *network* can be abbreviated as *net* or just
		Connect with the specified nickname. By default, the account's username
		is used.

*certfp generate* *[options...]* <network name>
	Generate self-signed certificate and use it for authentication.

	Generates RSA-3072 private key by default.

	Options are:

	*-key-type* <type>
		Private key algoritm to use. Valid values are: rsa, ecdsa, ed25519.
		ecdsa uses NIST P-521 curve.

	*-bits* <bits>
		Size of RSA key to generate. Ignored for other key types.

*certfp fingerprint* <network name>
	Show SHA-1 and SHA-256 fingerprints for the certificate
	currently used with the network.

*certfp reset* <network name>
	Disable SASL EXTERNAL authentication and remove stored certificate.

*network delete* <name>
	Disconnect and delete a network.


M downstream.go => downstream.go +6 -0
@@ 1523,6 1523,12 @@ func (dc *downstreamConn) handleNickServPRIVMSG(uc *upstreamConn, text string) {
		return
	}

	// User may have e.g. EXTERNAL mechanism configured. We do not want to
	// automatically erase the key pair or any other credentials.
	if uc.network.SASL.Mechanism != "" && uc.network.SASL.Mechanism != "PLAIN" {
		return
	}

	dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
	n := uc.network
	n.SASL.Mechanism = "PLAIN"

M service.go => service.go +164 -0
@@ 1,10 1,24 @@
package soju

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"math/big"
	"strings"
	"time"

	"github.com/google/shlex"
	"golang.org/x/crypto/bcrypt"


@@ 119,6 133,25 @@ func init() {
				},
			},
		},
		"certfp": {
			children: serviceCommandSet{
				"generate": {
					usage:  "[-key-type rsa|ecdsa|ed25519] [-bits N] <network name>",
					desc:   "generate a new self-signed certificate, defaults to using RSA-3072 key",
					handle: handleServiceCertfpGenerate,
				},
				"fingerprint": {
					usage:  "<network name>",
					desc:   "show fingerprints of certificate associated with the network",
					handle: handleServiceCertfpFingerprints,
				},
				"reset": {
					usage:  "<network name>",
					desc:   "disable SASL EXTERNAL authentication and remove stored certificate",
					handle: handleServiceCertfpReset,
				},
			},
		},
		"change-password": {
			usage:  "<new password>",
			desc:   "change your password",


@@ 127,6 160,137 @@ func init() {
	}
}

func handleServiceCertfpGenerate(dc *downstreamConn, params []string) error {
	fs := newFlagSet()
	keyType := fs.String("key-type", "rsa", "key type to generate (rsa, ecdsa, ed25519)")
	bits := fs.Int("bits", 3072, "size of key to generate, meaningful only for RSA")

	if err := fs.Parse(params); err != nil {
		return err
	}

	if len(fs.Args()) != 1 {
		return errors.New("exactly one argument is required")
	}

	net := dc.user.getNetwork(fs.Arg(0))
	if net == nil {
		return fmt.Errorf("unknown network %q", fs.Arg(0))
	}

	var (
		privKey crypto.PrivateKey
		pubKey  crypto.PublicKey
	)
	switch *keyType {
	case "rsa":
		key, err := rsa.GenerateKey(rand.Reader, *bits)
		if err != nil {
			return err
		}
		privKey = key
		pubKey = key.Public()
	case "ecdsa":
		key, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
		if err != nil {
			return err
		}
		privKey = key
		pubKey = key.Public()
	case "ed25519":
		var err error
		pubKey, privKey, err = ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return err
		}
	}

	// Using PKCS#8 allows easier extension for new key types.
	privKeyBytes, err := x509.MarshalPKCS8PrivateKey(privKey)
	if err != nil {
		return err
	}

	notBefore := time.Now()
	// Lets make a fair assumption nobody will use the same cert for more than 20 years...
	notAfter := notBefore.Add(24 * time.Hour * 365 * 20)
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return err
	}
	cert := &x509.Certificate{
		SerialNumber: serialNumber,
		Subject:      pkix.Name{CommonName: "soju auto-generated certificate"},
		NotBefore:    notBefore,
		NotAfter:     notAfter,
		KeyUsage:     x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
	}
	derBytes, err := x509.CreateCertificate(rand.Reader, cert, cert, pubKey, privKey)
	if err != nil {
		return err
	}

	net.SASL.External.CertBlob = derBytes
	net.SASL.External.PrivKeyBlob = privKeyBytes
	net.SASL.Mechanism = "EXTERNAL"

	if err := dc.srv.db.StoreNetwork(net.Username, &net.Network); err != nil {
		return err
	}

	sendServicePRIVMSG(dc, "certificate generated")

	sha1Sum := sha1.Sum(derBytes)
	sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
	sha256Sum := sha256.Sum256(derBytes)
	sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))

	return nil
}

func handleServiceCertfpFingerprints(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	sha1Sum := sha1.Sum(net.SASL.External.CertBlob)
	sendServicePRIVMSG(dc, "SHA-1 fingerprint: "+hex.EncodeToString(sha1Sum[:]))
	sha256Sum := sha256.Sum256(net.SASL.External.CertBlob)
	sendServicePRIVMSG(dc, "SHA-256 fingerprint: "+hex.EncodeToString(sha256Sum[:]))
	return nil
}

func handleServiceCertfpReset(dc *downstreamConn, params []string) error {
	if len(params) != 1 {
		return fmt.Errorf("expected exactly one argument")
	}

	net := dc.user.getNetwork(params[0])
	if net == nil {
		return fmt.Errorf("unknown network %q", params[0])
	}

	net.SASL.External.CertBlob = nil
	net.SASL.External.PrivKeyBlob = nil

	if net.SASL.Mechanism == "EXTERNAL" {
		net.SASL.Mechanism = ""
	}
	if err := dc.srv.db.StoreNetwork(dc.user.Username, &net.Network); err != nil {
		return err
	}

	sendServicePRIVMSG(dc, "certificate reset")
	return nil
}

func appendServiceCommandSetHelp(cmds serviceCommandSet, prefix []string, l *[]string) {
	for name, cmd := range cmds {
		words := append(prefix, name)

M upstream.go => upstream.go +31 -1
@@ 1,7 1,10 @@
package soju

import (
	"crypto"
	"crypto/sha256"
	"crypto/tls"
	"crypto/x509"
	"encoding/base64"
	"errors"
	"fmt"


@@ 100,7 103,31 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
		}

		logger.Printf("connecting to TLS server at address %q", addr)
		netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, nil)

		var cfg *tls.Config
		if network.SASL.Mechanism == "EXTERNAL" {
			if network.SASL.External.CertBlob == nil {
				return nil, fmt.Errorf("missing certificate for authentication")
			}
			if network.SASL.External.PrivKeyBlob == nil {
				return nil, fmt.Errorf("missing private key for authentication")
			}
			key, err := x509.ParsePKCS8PrivateKey(network.SASL.External.PrivKeyBlob)
			if err != nil {
				return nil, fmt.Errorf("failed to parse private key: %v", err)
			}
			cfg = &tls.Config{
				Certificates: []tls.Certificate{
					{
						Certificate: [][]byte{network.SASL.External.CertBlob},
						PrivateKey:  key.(crypto.PrivateKey),
					},
				},
			}
			logger.Printf("using TLS client certificate %x", sha256.Sum256(network.SASL.External.CertBlob))
		}

		netConn, err = tls.DialWithDialer(&dialer, "tcp", addr, cfg)
	case "irc+insecure":
		if !strings.ContainsRune(addr, ':') {
			addr = addr + ":6667"


@@ 1399,6 1426,9 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {
		case "PLAIN":
			uc.logger.Printf("starting SASL PLAIN authentication with username %q", auth.Plain.Username)
			uc.saslClient = sasl.NewPlainClient("", auth.Plain.Username, auth.Plain.Password)
		case "EXTERNAL":
			uc.logger.Printf("starting SASL EXTERNAL authentication")
			uc.saslClient = sasl.NewExternalClient("")
		default:
			return fmt.Errorf("unsupported SASL mechanism %q", name)
		}