~emersion/soju

313c6e7f978215ba2373f43aff6102c0c5cfacfa — Simon Ser 10 days ago 4e84b41
Add support for post-connection-registration upstream SASL auth

Once the downstream connection has logged in with their bouncer
credentials, allow them to issue more SASL auths which will be
redirected to the upstream network. This allows downstream clients
to provide UIs to login to transparently login to upstream networks.
3 files changed, 202 insertions(+), 121 deletions(-)

M downstream.go
M upstream.go
M user.go
M downstream.go => downstream.go +161 -112
@@ 244,6 244,11 @@ var passthroughIsupport = map[string]bool{
	"WHOX":          true,
}

type downstreamSASL struct {
	server                       sasl.Server
	plainUsername, plainPassword string
}

type downstreamConn struct {
	conn



@@ 267,12 272,11 @@ type downstreamConn struct {
	capVersion      int
	supportedCaps   map[string]string
	caps            map[string]bool
	sasl            *downstreamSASL

	lastBatchRef uint64

	monitored casemapMap

	saslServer sasl.Server
}

func newDownstreamConn(srv *Server, ic ircConn, id uint64) *downstreamConn {


@@ 686,102 690,28 @@ func (dc *downstreamConn) handleMessageUnregistered(ctx context.Context, msg *ir
			return err
		}
	case "AUTHENTICATE":
		if !dc.caps["sasl"] {
			return ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
			}}
		}
		if len(msg.Params) == 0 {
			return ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{"*", "Missing AUTHENTICATE argument"},
			}}
		credentials, err := dc.handleAuthenticateCommand(msg)
		if err != nil {
			return err
		} else if credentials == nil {
			break
		}

		var resp []byte
		if msg.Params[0] == "*" {
			dc.saslServer = nil
			return ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLABORTED,
				Params:  []string{"*", "SASL authentication aborted"},
			}}
		} else if dc.saslServer == nil {
			mech := strings.ToUpper(msg.Params[0])
			switch mech {
			case "PLAIN":
				dc.saslServer = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
					// TODO: we can't use the command context here, because it
					// gets cancelled once the command handler returns. SASL
					// might take multiple AUTHENTICATE commands to complete.
					return dc.authenticate(context.TODO(), username, password)
				}))
			default:
				return ircError{&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.ERR_SASLFAIL,
					Params:  []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
				}}
			}
		} else if msg.Params[0] == "+" {
			resp = nil
		} else {
			// TODO: multi-line messages
			var err error
			resp, err = base64.StdEncoding.DecodeString(msg.Params[0])
			if err != nil {
				dc.saslServer = nil
				return ircError{&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.ERR_SASLFAIL,
					Params:  []string{"*", "Invalid base64-encoded response"},
				}}
			}
		}

		challenge, done, err := dc.saslServer.Next(resp)
		if err != nil {
			dc.saslServer = nil
			if ircErr, ok := err.(ircError); ok && ircErr.Message.Command == irc.ERR_PASSWDMISMATCH {
				return ircError{&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.ERR_SASLFAIL,
					Params:  []string{"*", ircErr.Message.Params[1]},
				}}
			}
			dc.SendMessage(&irc.Message{
		if err := dc.authenticate(ctx, credentials.plainUsername, credentials.plainPassword); err != nil {
			dc.logger.Printf("SASL authentication error: %v", err)
			dc.endSASL(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{"*", "SASL error"},
			})
			return fmt.Errorf("SASL authentication failed: %v", err)
		} else if done {
			dc.saslServer = nil
			// Technically we should send RPL_LOGGEDIN here. However we use
			// RPL_LOGGEDIN to mirror the upstream connection status. Let's see
			// how many clients that breaks. See:
			// https://github.com/ircv3/ircv3-specifications/pull/476
			dc.SendMessage(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.RPL_SASLSUCCESS,
				Params:  []string{dc.nick, "SASL authentication successful"},
			})
		} else {
			challengeStr := "+"
			if len(challenge) > 0 {
				challengeStr = base64.StdEncoding.EncodeToString(challenge)
			}

			// TODO: multi-line messages
			dc.SendMessage(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: "AUTHENTICATE",
				Params:  []string{challengeStr},
				Params:  []string{"Authentication failed"},
			})
			break
		}

		// Technically we should send RPL_LOGGEDIN here. However we use
		// RPL_LOGGEDIN to mirror the upstream connection status. Let's
		// see how many clients that breaks. See:
		// https://github.com/ircv3/ircv3-specifications/pull/476
		dc.endSASL(nil)
	case "BOUNCER":
		var subcommand string
		if err := parseMessageParams(msg, &subcommand); err != nil {


@@ 951,6 881,107 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
	return nil
}

func (dc *downstreamConn) handleAuthenticateCommand(msg *irc.Message) (result *downstreamSASL, err error) {
	defer func() {
		if err != nil {
			dc.sasl = nil
		}
	}()

	if !dc.caps["sasl"] {
		return nil, ircError{&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: irc.ERR_SASLFAIL,
			Params:  []string{"*", "AUTHENTICATE requires the \"sasl\" capability to be enabled"},
		}}
	}
	if len(msg.Params) == 0 {
		return nil, ircError{&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: irc.ERR_SASLFAIL,
			Params:  []string{"*", "Missing AUTHENTICATE argument"},
		}}
	}
	if msg.Params[0] == "*" {
		return nil, ircError{&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: irc.ERR_SASLABORTED,
			Params:  []string{"*", "SASL authentication aborted"},
		}}
	}

	var resp []byte
	if dc.sasl == nil {
		mech := strings.ToUpper(msg.Params[0])
		var server sasl.Server
		switch mech {
		case "PLAIN":
			server = sasl.NewPlainServer(sasl.PlainAuthenticator(func(identity, username, password string) error {
				dc.sasl.plainUsername = username
				dc.sasl.plainPassword = password
				return nil
			}))
		default:
			return nil, ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{"*", fmt.Sprintf("Unsupported SASL mechanism %q", mech)},
			}}
		}

		dc.sasl = &downstreamSASL{server: server}
	} else {
		// TODO: multi-line messages
		if msg.Params[0] == "+" {
			resp = nil
		} else if resp, err = base64.StdEncoding.DecodeString(msg.Params[0]); err != nil {
			return nil, ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{"*", "Invalid base64-encoded response"},
			}}
		}
	}

	challenge, done, err := dc.sasl.server.Next(resp)
	if err != nil {
		return nil, err
	} else if done {
		return dc.sasl, nil
	} else {
		challengeStr := "+"
		if len(challenge) > 0 {
			challengeStr = base64.StdEncoding.EncodeToString(challenge)
		}

		// TODO: multi-line messages
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: "AUTHENTICATE",
			Params:  []string{challengeStr},
		})
		return nil, nil
	}
}

func (dc *downstreamConn) endSASL(msg *irc.Message) {
	if dc.sasl == nil {
		return
	}

	dc.sasl = nil

	if msg != nil {
		dc.SendMessage(msg)
	} else {
		dc.SendMessage(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: irc.RPL_SASLSUCCESS,
			Params:  []string{dc.nick, "SASL authentication successful"},
		})
	}
}

func (dc *downstreamConn) setSupportedCap(name, value string) {
	prevValue, hasPrev := dc.supportedCaps[name]
	changed := !hasPrev || prevValue != value


@@ 1141,9 1172,8 @@ func (dc *downstreamConn) register(ctx context.Context) error {
		return fmt.Errorf("tried to register twice")
	}

	if dc.saslServer != nil {
		dc.saslServer = nil
		dc.SendMessage(&irc.Message{
	if dc.sasl != nil {
		dc.endSASL(&irc.Message{
			Prefix:  dc.srv.prefix(),
			Command: irc.ERR_SASLABORTED,
			Params:  []string{"*", "SASL authentication aborted"},


@@ 2330,6 2360,40 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
			Command: "INVITE",
			Params:  []string{upstreamUser, upstreamChannel},
		})
	case "AUTHENTICATE":
		// Post-connection-registration AUTHENTICATE is unsupported in
		// multi-upstream mode, or if the upstream doesn't support SASL
		uc := dc.upstream()
		if uc == nil || !uc.caps["sasl"] {
			return ircError{&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: irc.ERR_SASLFAIL,
				Params:  []string{dc.nick, "Upstream network authentication not supported"},
			}}
		}

		credentials, err := dc.handleAuthenticateCommand(msg)
		if err != nil {
			return err
		}

		if credentials != nil {
			if uc.saslClient != nil {
				dc.endSASL(&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.ERR_SASLFAIL,
					Params:  []string{dc.nick, "Another authentication attempt is already in progress"},
				})
				return nil
			}

			uc.logger.Printf("starting post-registration SASL PLAIN authentication with username %q", credentials.plainUsername)
			uc.saslClient = sasl.NewPlainClient("", credentials.plainUsername, credentials.plainPassword)
			uc.enqueueCommand(dc, &irc.Message{
				Command: "AUTHENTICATE",
				Params:  []string{"PLAIN"},
			})
		}
	case "MONITOR":
		// MONITOR is unsupported in multi-upstream mode
		uc := dc.upstream()


@@ 2700,23 2764,8 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.

func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstreamConn, text string) {
	username, password, ok := parseNickServCredentials(text, uc.nick)
	if !ok {
		return
	}

	// User may have e.g. EXTERNAL mechanism configured. We do not want to
	// automatically erase the key pair or any other credentials.
	if uc.network.SASL.Mechanism != "" && uc.network.SASL.Mechanism != "PLAIN" {
		return
	}

	dc.logger.Printf("auto-saving NickServ credentials with username %q", username)
	n := uc.network
	n.SASL.Mechanism = "PLAIN"
	n.SASL.Plain.Username = username
	n.SASL.Plain.Password = password
	if err := dc.srv.db.StoreNetwork(ctx, dc.user.ID, &n.Network); err != nil {
		dc.logger.Printf("failed to save NickServ credentials: %v", err)
	if ok {
		uc.network.autoSaveSASLPlain(ctx, username, password)
	}
}


M upstream.go => upstream.go +25 -9
@@ 31,6 31,7 @@ var permanentUpstreamCaps = map[string]bool{
	"labeled-response": true,
	"message-tags":     true,
	"multi-prefix":     true,
	"sasl":             true,
	"server-time":      true,
	"setname":          true,



@@ 293,6 294,12 @@ func (uc *upstreamConn) endPendingCommands() {
					Command: irc.RPL_ENDOFWHO,
					Params:  []string{dc.nick, mask, "End of /WHO"},
				})
			case "AUTHENTICATE":
				dc.endSASL(&irc.Message{
					Prefix:  dc.srv.prefix(),
					Command: irc.ERR_SASLABORTED,
					Params:  []string{dc.nick, "SASL authentication aborted"},
				})
			default:
				panic(fmt.Errorf("Unsupported pending command %q", pendingCmd.msg.Command))
			}


@@ 311,7 318,7 @@ func (uc *upstreamConn) sendNextPendingCommand(cmd string) {

func (uc *upstreamConn) enqueueCommand(dc *downstreamConn, msg *irc.Message) {
	switch msg.Command {
	case "LIST", "WHO":
	case "LIST", "WHO", "AUTHENTICATE":
		// Supported
	default:
		panic(fmt.Errorf("Unsupported pending command %q", msg.Command))


@@ 612,10 619,20 @@ func (uc *upstreamConn) handleMessage(msg *irc.Message) error {
		uc.saslClient = nil
		uc.saslStarted = false

		uc.SendMessage(&irc.Message{
			Command: "CAP",
			Params:  []string{"END"},
		})
		if dc, _ := uc.dequeueCommand("AUTHENTICATE"); dc != nil && dc.sasl != nil {
			if msg.Command == irc.RPL_SASLSUCCESS {
				uc.network.autoSaveSASLPlain(context.TODO(), dc.sasl.plainUsername, dc.sasl.plainPassword)
			}

			dc.endSASL(msg)
		}

		if !uc.registered {
			uc.SendMessage(&irc.Message{
				Command: "CAP",
				Params:  []string{"END"},
			})
		}
	case irc.RPL_WELCOME:
		uc.registered = true
		uc.logger.Printf("connection registered")


@@ 1704,10 1721,6 @@ func (uc *upstreamConn) requestCaps() {
		}
	}

	if uc.requestSASL() && !uc.caps["sasl"] {
		requestCaps = append(requestCaps, "sasl")
	}

	if len(requestCaps) == 0 {
		return
	}


@@ 1749,6 1762,9 @@ func (uc *upstreamConn) handleCapAck(name string, ok bool) error {

	switch name {
	case "sasl":
		if !uc.requestSASL() {
			return nil
		}
		if !ok {
			uc.logger.Printf("server refused to acknowledge the SASL capability")
			return nil

M user.go => user.go +16 -0
@@ 404,6 404,22 @@ func (net *network) detachedMessageNeedsRelay(ch *Channel, msg *irc.Message) boo
	return ch.RelayDetached == FilterMessage || ((ch.RelayDetached == FilterHighlight || ch.RelayDetached == FilterDefault) && highlight)
}

func (net *network) autoSaveSASLPlain(ctx context.Context, username, password string) {
	// User may have e.g. EXTERNAL mechanism configured. We do not want to
	// automatically erase the key pair or any other credentials.
	if net.SASL.Mechanism != "" && net.SASL.Mechanism != "PLAIN" {
		return
	}

	net.logger.Printf("auto-saving SASL PLAIN credentials with username %q", username)
	net.SASL.Mechanism = "PLAIN"
	net.SASL.Plain.Username = username
	net.SASL.Plain.Password = password
	if err := net.user.srv.db.StoreNetwork(ctx, net.user.ID, &net.Network); err != nil {
		net.logger.Printf("failed to save SASL PLAIN credentials: %v", err)
	}
}

type user struct {
	User
	srv    *Server