@@ 16,6 16,7 @@ import (
)
type SASLClient interface {
+ Early() bool
Handshake() (mech string)
Respond(challenge string) (res string, err error)
}
@@ 25,6 26,10 @@ type SASLPlain struct {
Password string
}
+func (auth *SASLPlain) Early() bool {
+ return true
+}
+
func (auth *SASLPlain) Handshake() (mech string) {
mech = "PLAIN"
return
@@ 175,6 180,17 @@ func NewSession(out chan<- Message, params SessionParams) *Session {
}
s.out <- NewMessage("NICK", s.nick)
s.out <- NewMessage("USER", s.user, "0", "*", s.real)
+ if s.auth != nil && s.auth.Early() {
+ h := s.auth.Handshake()
+ s.out <- NewMessage("AUTHENTICATE", h)
+ res, err := s.auth.Respond("+")
+ if err != nil {
+ s.out <- NewMessage("AUTHENTICATE", "*")
+ } else {
+ s.out <- NewMessage("AUTHENTICATE", res)
+ }
+ s.auth = nil
+ }
if s.auth == nil {
s.endRegistration()
@@ 574,7 590,9 @@ func (s *Session) handleUnregistered(msg Message) (Event, error) {
s.out <- NewMessage("NICK", nick+"_")
case rplSaslsuccess:
- s.endRegistration()
+ if s.auth != nil {
+ s.endRegistration()
+ }
default:
return s.handleRegistered(msg)
}
@@ 647,7 665,9 @@ func (s *Session) handleMessageRegistered(msg Message, playback bool) (Event, er
s.user = prefix.User
s.host = prefix.Host
case errNicklocked, errSaslfail, errSasltoolong, errSaslaborted, errSaslalready, rplSaslmechs:
- s.endRegistration()
+ if s.auth != nil {
+ s.endRegistration()
+ }
return ErrorEvent{
Severity: SeverityFail,
Code: msg.Command,
@@ 1350,6 1370,10 @@ func (s *Session) handleMessageRegistered(msg Message, playback bool) (Event, er
if len(msg.Params) < 2 {
return nil, msg.errNotEnoughParams(2)
}
+ if msg.Command == errUnknowncommand && msg.Params[1] == "BOUNCER" {
+ // ignore any error in response to unconditional BOUNCER LISTNETWORKS
+ return nil, nil
+ }
return ErrorEvent{
Severity: ReplySeverity(msg.Command),
Code: msg.Command,
@@ 1488,13 1512,11 @@ func (s *Session) endRegistration() {
if s.registered {
return
}
- if _, ok := s.enabledCaps["soju.im/bouncer-networks"]; !ok {
- s.out <- NewMessage("CAP", "END")
- } else if s.netID == "" {
+ if s.netID != "" {
+ s.out <- NewMessage("BOUNCER", "BIND", s.netID)
s.out <- NewMessage("CAP", "END")
- s.out <- NewMessage("BOUNCER", "LISTNETWORKS")
} else {
- s.out <- NewMessage("BOUNCER", "BIND", s.netID)
s.out <- NewMessage("CAP", "END")
+ s.out <- NewMessage("BOUNCER", "LISTNETWORKS")
}
}