@@ 244,8 244,30 @@ func NewServerSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
// so the handler should not close over the session or use any of its send
// methods or a deadlock will occur.
// After Serve finishes running the handler, it flushes the output stream.
-func (s *Session) Serve(h Handler) error {
- return handleInputStream(s, h)
+func (s *Session) Serve(h Handler) (err error) {
+ if h == nil {
+ h = nopHandler{}
+ }
+
+ defer func() {
+ s.closeInputStream()
+ e := s.Close()
+ if err == nil {
+ err = e
+ }
+ }()
+
+ for {
+ select {
+ case <-s.in.ctx.Done():
+ return s.in.ctx.Err()
+ default:
+ }
+ err := handleInputStream(s, h)
+ if err != nil {
+ return s.sendError(err)
+ }
+ }
}
// sendError transmits an error on the session. If the error is not a standard
@@ 303,158 325,132 @@ func (r iqResponder) Close() error {
}
func handleInputStream(s *Session, handler Handler) (err error) {
- if handler == nil {
- handler = nopHandler{}
- }
-
- defer func() {
- s.closeInputStream()
- e := s.Close()
- if err == nil {
- err = e
- }
- }()
-
discard := xmlstream.Discard()
+ r := s.TokenReader()
+ defer r.Close()
- for {
- select {
- case <-s.in.ctx.Done():
- return s.in.ctx.Err()
- default:
- }
- tok, err := s.in.d.Token()
- if err != nil {
- // If this was a read timeout, don't try to send it. Just try to read
- // again.
- if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
- continue
- }
- return s.sendError(err)
+ tok, err := r.Token()
+ if err != nil {
+ // If this was a read timeout, don't try to send it. Just try to read
+ // again.
+ if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
+ return nil
}
+ return err
+ }
- var start xml.StartElement
- switch t := tok.(type) {
- case xml.StartElement:
- start = t
- case xml.EndElement:
- if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
- return nil
- }
- // If this is a stream level end element but not </stream:stream>,
- // something is really weird…
- return s.sendError(stream.BadFormat)
- default:
- // If this isn't a start element, the stream is in a bad state.
- return s.sendError(stream.BadFormat)
+ var start xml.StartElement
+ switch t := tok.(type) {
+ case xml.StartElement:
+ start = t
+ case xml.EndElement:
+ if t.Name.Space == ns.Stream && t.Name.Local == "stream" {
+ return nil
}
-
- // Handle stream errors and unknown stream namespaced tokens first, before
- // delegating to the normal handler.
- if start.Name.Space == ns.Stream {
- switch start.Name.Local {
- case "error":
- // TODO: Unmarshal the error and return it.
- return nil
- default:
- return s.sendError(stream.UnsupportedStanzaType)
- }
+ // If this is a stream level end element but not </stream:stream>,
+ // something is really weird…
+ return stream.BadFormat
+ default:
+ // If this isn't a start element, the stream is in a bad state.
+ return stream.BadFormat
+ }
+
+ // Handle stream errors and unknown stream namespaced tokens first, before
+ // delegating to the normal handler.
+ if start.Name.Space == ns.Stream {
+ switch start.Name.Local {
+ case "error":
+ // TODO: Unmarshal the error and return it.
+ return nil
+ default:
+ return stream.UnsupportedStanzaType
}
+ }
- // If this is a stanza, normalize the "from" attribute.
- if isStanza(start.Name) {
- for i, attr := range start.Attr {
- if attr.Name.Local == "from" /*&& attr.Name.Space == start.Name.Space*/ {
- local := s.LocalAddr().Bare().String()
- // Try a direct comparison first to avoid expensive JID parsing.
- // TODO: really we should be parsing the JID here in case the server
- // is using a different version of PRECIS, stringprep, etc. and the
- // canonical representation isn't the same.
- if attr.Value == local {
- start.Attr[i].Value = ""
- }
- break
+ // If this is a stanza, normalize the "from" attribute.
+ if isStanza(start.Name) {
+ for i, attr := range start.Attr {
+ if attr.Name.Local == "from" /*&& attr.Name.Space == start.Name.Space*/ {
+ local := s.LocalAddr().Bare().String()
+ // Try a direct comparison first to avoid expensive JID parsing.
+ // TODO: really we should be parsing the JID here in case the server
+ // is using a different version of PRECIS, stringprep, etc. and the
+ // canonical representation isn't the same.
+ if attr.Value == local {
+ start.Attr[i].Value = ""
}
+ break
}
}
+ }
- var id string
- var needsResp bool
- if isIQ(start.Name) {
- _, id = getID(start)
-
- // If this is a response IQ (ie. an "error" or "result") check if we're
- // handling it as part of a SendIQ call.
- // If not, record this so that we can check if the user sends a response
- // later.
- if !iqNeedsResp(start.Attr) {
- s.sentIQMutex.Lock()
- c := s.sentIQs[id]
- s.sentIQMutex.Unlock()
- if c == nil {
- goto noreply
- }
-
- c <- iqResponder{
- r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(s.in.d), xmlstream.Token(start.End())),
- c: c,
- }
- <-c
- // Consume the rest of the stream before continuing the loop.
- _, err = xmlstream.Copy(discard, s.in.d)
- if err != nil {
- return s.sendError(err)
- }
- continue
- } else {
- needsResp = true
+ var id string
+ var needsResp bool
+ if isIQ(start.Name) {
+ _, id = getID(start)
+
+ // If this is a response IQ (ie. an "error" or "result") check if we're
+ // handling it as part of a SendIQ call.
+ // If not, record this so that we can check if the user sends a response
+ // later.
+ if !iqNeedsResp(start.Attr) {
+ s.sentIQMutex.Lock()
+ c := s.sentIQs[id]
+ s.sentIQMutex.Unlock()
+ if c == nil {
+ goto noreply
}
- }
- noreply:
-
- err = func() error {
- r := s.TokenReader()
- w := s.TokenWriter()
- defer r.Close()
- defer w.Close()
-
- rw := &responseChecker{
- TokenReader: xmlstream.Inner(r),
- TokenWriter: w,
- id: id,
+ c <- iqResponder{
+ r: xmlstream.MultiReader(xmlstream.Token(start), xmlstream.Inner(r), xmlstream.Token(start.End())),
+ c: c,
}
- if err := handler.HandleXMPP(rw, &start); err != nil {
+ <-c
+ // Consume the rest of the stream before continuing the loop.
+ _, err = xmlstream.Copy(discard, r)
+ if err != nil {
return err
}
+ return nil
+ }
+ needsResp = true
+ }
- // If the user did not write a response to an IQ, send a default one.
- if needsResp && !rw.wroteResp {
- _, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
- ID: id,
- Type: stanza.ErrorIQ,
- }, stanza.Error{
- Type: stanza.Cancel,
- Condition: stanza.ServiceUnavailable,
- }.TokenReader()))
- if err != nil {
- return err
- }
- }
+noreply:
- if err := w.Flush(); err != nil {
- return err
- }
+ w := s.TokenWriter()
+ defer w.Close()
+ rw := &responseChecker{
+ TokenReader: xmlstream.Inner(r),
+ TokenWriter: w,
+ id: id,
+ }
+ if err := handler.HandleXMPP(rw, &start); err != nil {
+ return err
+ }
- // Advance to the end of the current element before attempting to read the
- // next.
- _, err = xmlstream.Copy(discard, rw)
- return err
- }()
+ // If the user did not write a response to an IQ, send a default one.
+ if needsResp && !rw.wroteResp {
+ _, err := xmlstream.Copy(w, stanza.WrapIQ(stanza.IQ{
+ ID: id,
+ Type: stanza.ErrorIQ,
+ }, stanza.Error{
+ Type: stanza.Cancel,
+ Condition: stanza.ServiceUnavailable,
+ }.TokenReader()))
if err != nil {
- return s.sendError(err)
+ return err
}
}
+
+ if err := w.Flush(); err != nil {
+ return err
+ }
+
+ // Advance to the end of the current element before attempting to read the
+ // next.
+ _, err = xmlstream.Copy(discard, rw)
+ return err
}
type responseChecker struct {