// Copyright 2016 Sam Whited. // Use of this source code is governed by the BSD 2-clause license that can be // found in the LICENSE file. package xmpp import ( "context" "crypto/tls" "encoding/xml" "errors" "fmt" "io" "mellium.im/sasl" "mellium.im/xmpp/internal/saslerr" "mellium.im/xmpp/ns" "mellium.im/xmpp/streamerror" ) // BUG(ssw): SASL feature does not have security layer byte precision. // SASL returns a stream feature for performing authentication using the Simple // Authentication and Security Layer (SASL) as defined in RFC 4422. It panics if // no mechanisms are specified. The order in which mechanisms are specified will // be the prefered order, so stronger mechanisms should be listed first. func SASL(mechanisms ...sasl.Mechanism) StreamFeature { if len(mechanisms) == 0 { panic("xmpp: Must specify at least 1 SASL mechanism") } return StreamFeature{ Name: xml.Name{Space: ns.SASL, Local: "mechanisms"}, Necessary: Secure, Prohibited: Authn, List: func(ctx context.Context, conn io.Writer) (req bool, err error) { req = true _, err = fmt.Fprint(conn, ``) if err != nil { return } for _, m := range mechanisms { select { case <-ctx.Done(): return true, ctx.Err() default: } if _, err = fmt.Fprint(conn, ``); err != nil { return } if err = xml.EscapeText(conn, []byte(m.Name)); err != nil { return } if _, err = fmt.Fprint(conn, ``); err != nil { return } } _, err = fmt.Fprint(conn, ``) return }, Parse: func(ctx context.Context, d *xml.Decoder, start *xml.StartElement) (bool, interface{}, error) { parsed := struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanisms"` List []string `xml:"urn:ietf:params:xml:ns:xmpp-sasl mechanism"` }{} err := d.DecodeElement(&parsed, start) return true, parsed.List, err }, Negotiate: func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error) { if (conn.state & Received) == Received { panic("SASL server not yet implemented") } else { var selected sasl.Mechanism // Select a mechanism, prefering the client order. selectmechanism: for _, m := range mechanisms { for _, name := range data.([]string) { if name == m.Name { selected = m break selectmechanism } } } // No matching mechanism found… if selected.Name == "" { return mask, nil, errors.New(`No matching SASL mechanisms found`) } c := conn.Config() opts := []sasl.Option{ sasl.Authz(c.Identity), sasl.Credentials(c.Username, c.Password), sasl.RemoteMechanisms(data.([]string)...), } if tlsconn, ok := conn.rwc.(*tls.Conn); ok { opts = append(opts, sasl.ConnState(tlsconn.ConnectionState())) } client := sasl.NewClient(selected, opts...) more, resp, err := client.Step(nil) if err != nil { return mask, nil, err } // RFC6120 §6.4.2: // If the initiating entity needs to send a zero-length initial // response, it MUST transmit the response as a single equals sign // character ("="), which indicates that the response is present but // contains no data. if len(resp) == 0 { resp = []byte{'='} } // Send and the initial payload to start SASL auth. if _, err = fmt.Fprintf(conn, `%s`, selected.Name, resp, ); err != nil { return mask, nil, err } // If we're already done after the first step, decode the or // before we exit. if !more { tok, err := conn.in.d.Token() if err != nil { return mask, nil, err } if t, ok := tok.(xml.StartElement); ok { // TODO: Handle the additional data that could be returned if // success? _, _, err := decodeSASLChallenge(conn.in.d, t, false) if err != nil { return mask, nil, err } } else { return mask, nil, streamerror.BadFormat } } success := false for more { select { case <-ctx.Done(): return mask, nil, ctx.Err() default: } tok, err := conn.in.d.Token() if err != nil { return mask, nil, err } var challenge []byte if t, ok := tok.(xml.StartElement); ok { challenge, success, err = decodeSASLChallenge(conn.in.d, t, true) if err != nil { return mask, nil, err } } else { return mask, nil, streamerror.BadFormat } if more, resp, err = client.Step(challenge); err != nil { return mask, nil, err } if !more && success { // We're done with SASL and we're successful break } // TODO: What happens if there's more and success (broken server)? if _, err = fmt.Fprintf(conn, `%s`, resp); err != nil { return mask, nil, err } } return Authn, conn.Raw(), nil } }, } } func decodeSASLChallenge(d *xml.Decoder, start xml.StartElement, allowChallenge bool) (challenge []byte, success bool, err error) { switch start.Name { case xml.Name{Space: ns.SASL, Local: "challenge"}, xml.Name{Space: ns.SASL, Local: "success"}: if !allowChallenge && start.Name.Local == "challenge" { return nil, false, streamerror.UnsupportedStanzaType } challenge := struct { Data []byte `xml:",chardata"` }{} if err = d.DecodeElement(&challenge, &start); err != nil { return nil, false, err } return challenge.Data, start.Name.Local == "success", nil case xml.Name{Space: ns.SASL, Local: "failure"}: fail := saslerr.Failure{} if err = d.DecodeElement(&fail, &start); err != nil { return nil, false, err } return nil, false, fail default: return nil, false, streamerror.UnsupportedStanzaType } }