~samwhited/xmpp

1fde60c27602e5a40dd80739b37f3de18bfc229f — Sam Whited 2 years ago 19734c6
all: move session Flush into TokenWriter
9 files changed, 33 insertions(+), 25 deletions(-)

M bind.go
M features.go
M go.mod
M go.sum
M ibr2/ibr2.go
M mux/mux.go
M mux/mux_test.go
M session.go
M session_test.go
M bind.go => bind.go +2 -2
@@ 161,7 161,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
				if err != nil {
					return mask, nil, err
				}
				return mask, nil, session.Flush()
				return mask, nil, w.Flush()
			}

			// Client encodes an IQ requesting resource binding.


@@ 179,7 179,7 @@ func bind(server func(jid.JID, string) (jid.JID, error)) StreamFeature {
			if err != nil {
				return mask, nil, err
			}
			if err = session.Flush(); err != nil {
			if err = w.Flush(); err != nil {
				return mask, nil, err
			}


M features.go => features.go +1 -1
@@ 288,7 288,7 @@ func writeStreamFeatures(ctx context.Context, s *Session, features []StreamFeatu
	if err = w.EncodeToken(start.End()); err != nil {
		return list, err
	}
	if err = s.Flush(); err != nil {
	if err = w.Flush(); err != nil {
		return list, err
	}
	return list, err

M go.mod => go.mod +1 -1
@@ 8,5 8,5 @@ require (
	golang.org/x/net v0.0.0-20190611141213-3f473d35a33a
	golang.org/x/text v0.3.2
	mellium.im/sasl v0.2.1
	mellium.im/xmlstream v0.13.3
	mellium.im/xmlstream v0.13.4
)

M go.sum => go.sum +2 -2
@@ 17,5 17,5 @@ mellium.im/reader v0.1.0 h1:UUEMev16gdvaxxZC7fC08j7IzuDKh310nB6BlwnxTww=
mellium.im/reader v0.1.0/go.mod h1:F+X5HXpkIfJ9EE1zHQG9lM/hO946iYAmU7xjg5dsQHI=
mellium.im/sasl v0.2.1 h1:nspKSRg7/SyO0cRGY71OkfHab8tf9kCts6a6oTDut0w=
mellium.im/sasl v0.2.1/go.mod h1:ROaEDLQNuf9vjKqE1SrAfnsobm2YKXT1gnN1uDp1PjQ=
mellium.im/xmlstream v0.13.3 h1:tc2HN6gNIBbpB/uTgxZgw7QV+GNkSI8dSfpruJOBf9I=
mellium.im/xmlstream v0.13.3/go.mod h1:n9o+Vjw+o977AcxnRnb+5dKfzniGZV4DSQKgQLNYWaU=
mellium.im/xmlstream v0.13.4 h1:AN9dkVD9K/CHzp21IVKnKXYUCsBjCEsrPqW8k1/tgGw=
mellium.im/xmlstream v0.13.4/go.mod h1:O7wqreSmFi1LOh4RiK7r2j4H4pYDgzo1qv5ZkYJZ7Ns=

M ibr2/ibr2.go => ibr2/ibr2.go +1 -1
@@ 163,7 163,7 @@ func negotiateFunc(challenges ...Challenge) func(context.Context, *xmpp.Session,
				if err != nil {
					return
				}
				err = session.Flush()
				err = w.Flush()
				if err != nil {
					return
				}

M mux/mux.go => mux/mux.go +4 -1
@@ 87,7 87,10 @@ func fallback(s *xmpp.Session, start *xml.StartElement) error {
	w := s.TokenWriter()
	defer w.Close()
	_, err := xmlstream.Copy(w, xmlstream.Wrap(e.TokenReader(), *start))
	return err
	if err != nil {
		return err
	}
	return w.Flush()
}

// New allocates and returns a new ServeMux.

M mux/mux_test.go => mux/mux_test.go +0 -3
@@ 108,9 108,6 @@ func TestFallback(t *testing.T) {
	if err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}
	if err = s.Flush(); err != nil {
		t.Errorf("Unexpected error: `%v'", err)
	}

	const expected = `<iq to="juliet@example.com" from="romeo@example.com" id="123" type="error"><error type="cancel"><feature-not-implemented xmlns="urn:ietf:params:xml:ns:xmpp-stanzas"></feature-not-implemented></error></iq>`
	if buf.String() != expected {

M session.go => session.go +21 -13
@@ 440,7 440,7 @@ func handleInputStream(s *Session, handler Handler) (err error) {
			}
		}

		if err := s.Flush(); err != nil {
		if err := s.out.e.Flush(); err != nil {
			return err
		}



@@ 518,6 518,16 @@ func (lwc *lockWriteCloser) EncodeToken(t xml.Token) error {
	return lwc.w.out.e.EncodeToken(t)
}

func (lwc *lockWriteCloser) Flush() error {
	if lwc.err != nil {
		return nil
	}
	if lwc.w.state&OutputStreamClosed == OutputStreamClosed {
		return ErrOutputStreamClosed
	}
	return lwc.w.out.e.Flush()
}

func (lwc *lockWriteCloser) Close() error {
	if lwc.err != nil {
		return nil


@@ 560,7 570,7 @@ func (lrc *lockReadCloser) Close() error {
// method is called.
// After the TokenWriteCloser has been closed, any future writes will return
// io.EOF.
func (s *Session) TokenWriter() xmlstream.TokenWriteCloser {
func (s *Session) TokenWriter() xmlstream.TokenWriteFlushCloser {
	s.out.Lock()

	return &lockWriteCloser{


@@ 584,14 594,6 @@ func (s *Session) TokenReader() xmlstream.TokenReadCloser {
	}
}

// Flush satisfies the xmlstream.TokenWriter interface.
func (s *Session) Flush() error {
	if s.state&OutputStreamClosed == OutputStreamClosed {
		return ErrOutputStreamClosed
	}
	return s.out.e.Flush()
}

// Close ends the output stream (by sending a closing </stream:stream> token).
// It does not close the underlying connection.
// Calling Close() multiple times will only result in one closing


@@ 670,7 672,10 @@ func (s *Session) Encode(v interface{}) error {
		return err
	}
	_, err = xmlstream.Copy(s.out.e, xml.NewDecoder(&b))
	return err
	if err != nil {
		return err
	}
	return s.out.e.Flush()
}

// EncodeElement writes the XML encoding of v to the stream, using start as the


@@ 691,7 696,10 @@ func (s *Session) EncodeElement(v interface{}, start xml.StartElement) error {
		return err
	}
	_, err = xmlstream.Copy(s.out.e, xml.NewDecoder(&b))
	return err
	if err != nil {
		return err
	}
	return s.out.e.Flush()
}

// Send transmits the first element read from the provided token reader.


@@ 734,7 742,7 @@ func (s *Session) SendElement(ctx context.Context, r xml.TokenReader, start xml.
	if err != nil {
		return err
	}
	return s.Flush()
	return s.out.e.Flush()
}

func iqNeedsResp(attrs []xml.Attr) bool {

M session_test.go => session_test.go +1 -1
@@ 59,7 59,7 @@ func TestClosedOutputStream(t *testing.T) {
			case mask&xmpp.OutputStreamClosed == 0 && err != nil:
				t.Errorf("Unexpected error: `%v'", err)
			}
			switch err := s.Flush(); {
			switch err := w.Flush(); {
			case mask&xmpp.OutputStreamClosed == xmpp.OutputStreamClosed && err != xmpp.ErrOutputStreamClosed:
				t.Errorf("Unexpected error flushing: want=`%v', got=`%v'", xmpp.ErrOutputStreamClosed, err)
			case mask&xmpp.OutputStreamClosed == 0 && err != nil: