// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

package xmpp

import (
	"context"
	"encoding/xml"
	"fmt"
	"io"

	"golang.org/x/text/language"
	"mellium.im/xmpp/internal"
	"mellium.im/xmpp/jid"
	"mellium.im/xmpp/ns"
	"mellium.im/xmpp/streamerror"
)

// streamIDLength is the length passed to internal.RandomID when a stream ID is
// generated for a stream we send as the receiving entity.
const streamIDLength = 16

// SessionState represents the current state of an XMPP session as a set of
// bit flags. For a description of each bit, see the various SessionState typed
// constants. Flags are tested with a mask, e.g. state&Ready == Ready.
type SessionState int8

const (
	// Secure indicates that the underlying connection has been secured. For
	// instance, after STARTTLS has been performed or if a pre-secured connection
	// is being used such as websockets over HTTPS.
	Secure SessionState = 1 << iota

	// Authn indicates that the session has been authenticated (probably with
	// SASL).
	Authn

	// Ready indicates that the session is fully negotiated and that XMPP stanzas
	// may be sent and received.
	Ready

	// StreamRestartRequired indicates that the session's streams must be
	// restarted. This bit will trigger an automatic restart and will be flipped
	// back to off as soon as the stream is restarted.
	StreamRestartRequired

	// Received indicates that the session was initiated by a foreign entity.
	Received

	// EndStream indicates that the stream should be (or has been) terminated.
	// After being flipped, this bit is left off unless the stream is restarted.
	// This does not provide any information about the underlying TLS connection.
	EndStream
)

// stream holds the information carried by the attributes of a stream header.
type stream struct {
	to      *jid.JID         // parsed from the "to" attribute
	from    *jid.JID         // parsed from the "from" attribute
	id      string           // raw value of the "id" attribute
	version internal.Version // parsed from the "version" attribute
	xmlns   string           // default namespace: "jabber:client" or "jabber:server"
	lang    language.Tag     // built from the xml:lang attribute
}

// streamFromStartElement builds a stream from the attributes of the given
// start element. This MUST only return stream errors.
func streamFromStartElement(s xml.StartElement) (stream, error) { stream := stream{} for _, attr := range s.Attr { switch attr.Name { case xml.Name{Space: "", Local: "to"}: stream.to = &jid.JID{} if err := stream.to.UnmarshalXMLAttr(attr); err != nil { return stream, streamerror.ImproperAddressing } case xml.Name{Space: "", Local: "from"}: stream.from = &jid.JID{} if err := stream.from.UnmarshalXMLAttr(attr); err != nil { return stream, streamerror.ImproperAddressing } case xml.Name{Space: "", Local: "id"}: stream.id = attr.Value case xml.Name{Space: "", Local: "version"}: (&stream.version).UnmarshalXMLAttr(attr) case xml.Name{Space: "", Local: "xmlns"}: if attr.Value != "jabber:client" && attr.Value != "jabber:server" { return stream, streamerror.InvalidNamespace } stream.xmlns = attr.Value case xml.Name{Space: "xmlns", Local: "stream"}: if attr.Value != ns.Stream { return stream, streamerror.InvalidNamespace } case xml.Name{Space: "xml", Local: "lang"}: stream.lang = language.Make(attr.Value) } } return stream, nil } // Sends a new XML header followed by a stream start element on the given // io.Writer. We don't use an xml.Encoder both because Go's standard library xml // package really doesn't like the namespaced stream:stream attribute and // because we can guarantee well-formedness of the XML with a print in this case // and printing is much faster than encoding. Afterwards, clear the // StreamRestartRequired bit and set the output stream information. 
func sendNewStream(w io.Writer, cfg *Config, id string) error { stream := stream{ to: cfg.Location, from: cfg.Origin, lang: cfg.Lang, version: cfg.Version, } switch cfg.S2S { case true: stream.xmlns = ns.Server case false: stream.xmlns = ns.Client } stream.id = id if id == "" { id = " " } else { id = ` id='` + id + `' ` } _, err := fmt.Fprint(w, xml.Header) if err != nil { return err } _, err = fmt.Fprintf(w, ``, id, cfg.Location.String(), cfg.Origin.String(), cfg.Version, cfg.Lang, stream.xmlns, ) if err != nil { return err } if conn, ok := w.(*Conn); ok { conn.out.stream = stream } return nil } func expectNewStream(ctx context.Context, r io.Reader) error { var foundHeader bool // If the reader is a Conn, use its decoder, otherwise make a new one. var d *xml.Decoder if conn, ok := r.(*Conn); ok { d = conn.in.d } else { d = xml.NewDecoder(r) } for { select { case <-ctx.Done(): return ctx.Err() default: } t, err := d.Token() if err != nil { return err } switch tok := t.(type) { case xml.StartElement: switch { case tok.Name.Local == "error" && tok.Name.Space == ns.Stream: se := streamerror.StreamError{} if err := d.DecodeElement(&se, &tok); err != nil { return err } return se case tok.Name.Local != "stream": return streamerror.BadFormat case tok.Name.Space != ns.Stream: return streamerror.InvalidNamespace } stream, err := streamFromStartElement(tok) switch { case err != nil: return err case stream.version != internal.DefaultVersion: return streamerror.UnsupportedVersion } if conn, ok := r.(*Conn); ok { if (conn.state&Received) != Received && stream.id == "" { // if we are the initiating entity and there is no stream ID… return streamerror.BadFormat } conn.in.stream = stream } return nil case xml.ProcInst: // TODO: If version or encoding are declared, validate XML 1.0 and UTF-8 if !foundHeader && tok.Target == "xml" { foundHeader = true continue } return streamerror.RestrictedXML case xml.EndElement: return streamerror.NotWellFormed default: return 
streamerror.RestrictedXML } } } func (c *Conn) negotiateStreams(ctx context.Context) (err error) { // Loop for as long as we're not done negotiating features or a stream restart // is still required. for done := false; !done || c.state&StreamRestartRequired == StreamRestartRequired; { if c.state&StreamRestartRequired == StreamRestartRequired { c.features = make(map[xml.Name]struct{}) c.in.d = xml.NewDecoder(c.rwc) c.out.e = xml.NewEncoder(c.rwc) c.state &= ^StreamRestartRequired if (c.state & Received) == Received { // If we're the receiving entity wait for a new stream, then send one in // response. if err = expectNewStream(ctx, c); err != nil { return err } if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil { return err } } else { // If we're the initiating entity, send a new stream and then wait for one // in response. if err := sendNewStream(c, c.config, ""); err != nil { return err } if err := expectNewStream(ctx, c); err != nil { return err } } } if done, err = c.negotiateFeatures(ctx); err != nil { return err } } return nil }