~samwhited/xmpp

72b4e29c50c53180f222596fbdcdfd580af19ba1 — Sam Whited 3 months ago 17665ea
xmpp: add RW mutex around session state

This fixes a race condition that could result in writes to a closed
stream, for example.

Signed-off-by: Sam Whited <sam@samwhited.com>
2 files changed, 31 insertions(+), 1 deletions(-)

M CHANGELOG.md
M session.go
M CHANGELOG.md => CHANGELOG.md +2 -0
@@ 24,6 24,8 @@ All notable changes to this project will be documented in this file.
### Fixed

- stanza: converting stanzas with empty to/from attributes no longer fails
- xmpp: fixed data race that could result in invalid session state and lead to
  writes on a closed session and other state related issues


## v0.17.0 — 2020-11-11

M session.go => session.go +29 -1
@@ 75,7 75,8 @@ type Session struct {
	conn      net.Conn
	connState func() tls.ConnectionState

	state SessionState
	state      SessionState
	stateMutex sync.RWMutex

	origin   jid.JID
	location jid.JID


@@ 134,6 135,8 @@ func NegotiateSession(ctx context.Context, location, origin jid.JID, rw io.ReadW
		s.connState = tc.ConnectionState
	}
	if received {
		// We don't need to lock the state mutex yet since we haven't returned
		// session and nothing else can access it.
		s.state |= Received
	}
	s.out.Locker = &sync.Mutex{}


@@ 289,9 292,12 @@ func (s *Session) sendError(err error) (e error) {
	s.out.Lock()
	defer s.out.Unlock()

	s.stateMutex.RLock()
	if s.state&OutputStreamClosed == OutputStreamClosed {
		s.stateMutex.RUnlock()
		return err
	}
	s.stateMutex.RUnlock()

	switch typErr := err.(type) {
	case stream.Error:


@@ 509,9 515,12 @@ func (lwc *lockWriteCloser) EncodeToken(t xml.Token) error {
		return lwc.err
	}

	lwc.w.stateMutex.RLock()
	if lwc.w.state&OutputStreamClosed == OutputStreamClosed {
		lwc.w.stateMutex.RUnlock()
		return ErrOutputStreamClosed
	}
	lwc.w.stateMutex.RUnlock()

	return lwc.w.out.e.EncodeToken(t)
}


@@ 520,9 529,12 @@ func (lwc *lockWriteCloser) Flush() error {
	if lwc.err != nil {
		return nil
	}
	lwc.w.stateMutex.RLock()
	if lwc.w.state&OutputStreamClosed == OutputStreamClosed {
		lwc.w.stateMutex.RUnlock()
		return ErrOutputStreamClosed
	}
	lwc.w.stateMutex.RUnlock()
	return lwc.w.out.e.Flush()
}



@@ 546,9 558,12 @@ func (lrc *lockReadCloser) Token() (xml.Token, error) {
		return nil, lrc.err
	}

	lrc.s.stateMutex.RLock()
	if lrc.s.state&InputStreamClosed == InputStreamClosed {
		lrc.s.stateMutex.RUnlock()
		return nil, ErrInputStreamClosed
	}
	lrc.s.stateMutex.RUnlock()

	return lrc.s.in.d.Token()
}


@@ 604,10 619,15 @@ func (s *Session) Close() error {
}

func (s *Session) closeSession() error {
	s.stateMutex.RLock()
	if s.state&OutputStreamClosed == OutputStreamClosed {
		s.stateMutex.RUnlock()
		return nil
	}
	s.stateMutex.RUnlock()

	s.stateMutex.Lock()
	defer s.stateMutex.Unlock()
	s.state |= OutputStreamClosed
	// We wrote the opening stream instead of encoding it, so do the same with the
	// closing to ensure that the encoder doesn't think the tokens are mismatched.


@@ 618,12 638,16 @@ func (s *Session) closeSession() error {
// State returns the current state of the session. For more information, see the
// SessionState type.
func (s *Session) State() SessionState {
	s.stateMutex.RLock()
	defer s.stateMutex.RUnlock()
	return s.state
}

// LocalAddr returns the Origin address for initiated connections, or the
// Location for received connections.
func (s *Session) LocalAddr() jid.JID {
	s.stateMutex.RLock()
	defer s.stateMutex.RUnlock()
	if (s.state & Received) == Received {
		return s.location
	}


@@ 633,6 657,8 @@ func (s *Session) LocalAddr() jid.JID {
// RemoteAddr returns the Location address for initiated connections, or the
// Origin address for received connections.
func (s *Session) RemoteAddr() jid.JID {
	s.stateMutex.RLock()
	defer s.stateMutex.RUnlock()
	if (s.state & Received) == Received {
		return s.origin
	}


@@ 863,6 889,8 @@ func (s *Session) sendResp(ctx context.Context, id string, payload xml.TokenRead
func (s *Session) closeInputStream() {
	s.in.Lock()
	defer s.in.Unlock()
	s.stateMutex.Lock()
	defer s.stateMutex.Unlock()
	s.state |= InputStreamClosed
	s.in.cancel()
}