~samwhited/xmpp

1c406eb9289cc80a386ba7a59c70be0dd30e82ac — Sam Whited 1 year, 7 months ago ee04abc
xmpp: minor refactor of input stream handling
1 files changed, 128 insertions(+), 132 deletions(-)

M session.go
M session.go => session.go +128 -132
@@ 244,8 244,30 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// so the handler should not close over the session or use any of its send
// methods or a deadlock will occur.
// After Serve finishes running the handler, it flushes the output stream.
func (s *Session) Serve(h Handler) error {
	return handleInputStream(s, h)
func (s *Session) Serve(h Handler) (err error) {
	if h == nil {
		h = nopHandler{}
	}

	defer func() {
		s.closeInputStream()
		e := s.Close()
		if err == nil {
			err = e
		}
	}()

	for {
		select {
		case <-s.in.ctx.Done():
			return s.in.ctx.Err()
		default:
		}
		err := handleInputStream(s, h)
		if err != nil {
			return s.sendError(err)
		}
	}
}

// sendError transmits an error on the session. If the error is not a standard


@@ 303,158 325,132 @@ func (r iqResponder) Close() error {
}

func handleInputStream(s *Session, handler Handler) (err error) {
	if handler == nil {
		handler = nopHandler{}
	}

	defer func() {
		s.closeInputStream()
		e := s.Close()
		if err == nil {
			err = e
		}
	}()

	discard := xmlstream.Discard()
	r := s.TokenReader()
	defer r.Close()

	for {
		select {
		case <-s.in.ctx.Done():
			return s.in.ctx.Err()
		default:
		}
		tok, err := s.in.d.Token()
		if err != nil {
			// If this was a read timeout, don't try to send it. Just try to read
			// again.
			if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
				continue
			}
			return s.sendError(err)
	tok, err := r.Token()
	if err != nil {
		// If this was a read timeout, don't try to send it. Just try to read
		// again.
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			return nil
		}
		return err
	}

		var start xml.StartElement
		switch t := tok.(type) {
		case xml.StartElement:
			start = t
		case xml.EndElement:
			if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
				return nil
			}
			// If this is a stream level end element but not </stream:stream>,
			// something is really weird…
			return s.sendError(stream.BadFormat)
		default:
			// If this isn't a start element, the stream is in a bad state.
			return s.sendError(stream.BadFormat)
	var start xml.StartElement
	switch t := tok.(type) {
	case xml.StartElement:
		start = t
	case xml.EndElement:
		if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
			return nil
		}

		// Handle stream errors and unknown stream namespaced tokens first, before
		// delegating to the normal handler.
		if start.Name.Space == ns.Stream {
			switch start.Name.Local {
			case "error":
				// TODO: Unmarshal the error and return it.
				return nil
			default:
				return s.sendError(stream.UnsupportedStanzaType)
			}
		// If this is a stream level end element but not </stream:stream>,
		// something is really weird…
		return stream.BadFormat
	default:
		// If this isn't a start element, the stream is in a bad state.
		return stream.BadFormat
	}

	// Handle stream errors and unknown stream namespaced tokens first, before
	// delegating to the normal handler.
	if start.Name.Space == ns.Stream {
		switch start.Name.Local {
		case "error":
			// TODO: Unmarshal the error and return it.
			return nil
		default:
			return stream.UnsupportedStanzaType
		}
	}

		// If this is a stanza, normalize the "from" attribute.
		if isStanza(start.Name) {
			for i, attr := range start.Attr {
				if attr.Name.Local == "from" /*&& attr.Name.Space == start.Name.Space*/ {
					local := s.LocalAddr().Bare().String()
					// Try a direct comparison first to avoid expensive JID parsing.
					// TODO: really we should be parsing the JID here in case the server
					// is using a different version of PRECIS, stringprep, etc. and the
					// canonical representation isn't the same.
					if attr.Value == local {
						start.Attr[i].Value = ""
					}
					break
	// If this is a stanza, normalize the "from" attribute.
	if isStanza(start.Name) {
		for i, attr := range start.Attr {
			if attr.Name.Local == "from" /*&& attr.Name.Space == start.Name.Space*/ {
				local := s.LocalAddr().Bare().String()
				// Try a direct comparison first to avoid expensive JID parsing.
				// TODO: really we should be parsing the JID here in case the server
				// is using a different version of PRECIS, stringprep, etc. and the
				// canonical representation isn't the same.
				if attr.Value == local {
					start.Attr[i].Value = ""
				}
				break
			}
		}
	}

		var id string
		var needsResp bool
		if isIQ(start.Name) {
			_, id = getID(start)

			// If this is a response IQ (ie. an "error" or "result") check if we're
			// handling it as part of a SendIQ call.
			// If not, record this so that we can check if the user sends a response
			// later.
			if !iqNeedsResp(start.Attr) {
				s.sentIQMutex.Lock()
				c := s.sentIQs[id]
				s.sentIQMutex.Unlock()
				if c == nil {
					goto noreply
				}

				c <- iqResponder{
					r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(s.in.d), xmlstream.Token(start.End())),
					c: c,
				}
				<-c
				// Consume the rest of the stream before continuing the loop.
				_, err = xmlstream.Copy(discard, s.in.d)
				if err != nil {
					return s.sendError(err)
				}
				continue
			} else {
				needsResp = true
	var id string
	var needsResp bool
	if isIQ(start.Name) {
		_, id = getID(start)

		// If this is a response IQ (ie. an "error" or "result") check if we're
		// handling it as part of a SendIQ call.
		// If not, record this so that we can check if the user sends a response
		// later.
		if !iqNeedsResp(start.Attr) {
			s.sentIQMutex.Lock()
			c := s.sentIQs[id]
			s.sentIQMutex.Unlock()
			if c == nil {
				goto noreply
			}
		}

	noreply:

		err = func() error {
			r := s.TokenReader()
			w := s.TokenWriter()
			defer r.Close()
			defer w.Close()

			rw := &responseChecker{
				TokenReader: xmlstream.Inner(r),
				TokenWriter: w,
				id:          id,
			c <- iqResponder{
				r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(r), xmlstream.Token(start.End())),
				c: c,
			}
			if err := handler.HandleXMPP(rw, &start); err != nil {
			<-c
			// Consume the rest of the stream before continuing the loop.
			_, err = xmlstream.Copy(discard, r)
			if err != nil {
				return err
			}
			return nil
		}
		needsResp = true
	}

			// If the user did not write a response to an IQ, send a default one.
			if needsResp && !rw.wroteResp {
				_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
					ID:   id,
					Type: stanza.ErrorIQ,
				}, stanza.Error{
					Type:      stanza.Cancel,
					Condition: stanza.ServiceUnavailable,
				}.TokenReader()))
				if err != nil {
					return err
				}
			}
noreply:

			if err := w.Flush(); err != nil {
				return err
			}
	w := s.TokenWriter()
	defer w.Close()
	rw := &responseChecker{
		TokenReader: xmlstream.Inner(r),
		TokenWriter: w,
		id:          id,
	}
	if err := handler.HandleXMPP(rw, &start); err != nil {
		return err
	}

			// Advance to the end of the current element before attempting to read the
			// next.
			_, err = xmlstream.Copy(discard, rw)
			return err
		}()
	// If the user did not write a response to an IQ, send a default one.
	if needsResp && !rw.wroteResp {
		_, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
			ID:   id,
			Type: stanza.ErrorIQ,
		}, stanza.Error{
			Type:      stanza.Cancel,
			Condition: stanza.ServiceUnavailable,
		}.TokenReader()))
		if err != nil {
			return s.sendError(err)
			return err
		}
	}

	if err := w.Flush(); err != nil {
		return err
	}

	// Advance to the end of the current element before attempting to read the
	// next.
	_, err = xmlstream.Copy(discard, rw)
	return err
}

type responseChecker struct {