~samwhited/xmpp

466c97d10966768f341a036a03849a31e335d7fb — Sam Whited 4 years ago e648226
Lookup stream features by namespace, not xmlname
3 files changed, 21 insertions(+), 19 deletions(-)

M conn.go
M features.go
M stream.go
M conn.go => conn.go +7 -7
@@ 33,12 33,12 @@ type Conn struct {
	// server did not assign us the resourcepart we requested, this is canonical).
	origin *jid.JID

	// The stream features advertised for the current streams.
	features map[xml.Name]struct{}
	// The stream feature namespaces advertised for the current streams.
	features map[string]struct{}
	flock    sync.Mutex

	// The negotiated features for the current session.
	negotiated map[xml.Name]struct{}
	// The negotiated features (by namespace) for the current session.
	negotiated map[string]struct{}

	in struct {
		sync.Mutex


@@ 52,9 52,9 @@ type Conn struct {
	}
}

// Features returns a set of the currently available stream features (including
// those that have already been negotiated).
func (c *Conn) Features() map[xml.Name]struct{} {
// Features returns a set of the currently available stream features namespaces
// (including namespaces for features that have already been negotiated).
func (c *Conn) Features() map[string]struct{} {
	c.flock.Lock()
	defer c.flock.Unlock()


M features.go => features.go +12 -10
@@ 125,8 125,8 @@ func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWri

			// If the feature was not sent or was already negotiated, error.

			_, negotiated := c.negotiated[start.Name]
			data, sent = list.cache[start.Name]
			_, negotiated := c.negotiated[start.Name.Space]
			data, sent = list.cache[start.Name.Space]
			if !sent || negotiated {
				// TODO: What should we return here?
				return done, rwc, streamerror.PolicyViolation


@@ 135,7 135,7 @@ func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWri
			// If we're the client, iterate through the cached features and select one
			// to negotiate.
			for _, v := range list.cache {
				if _, ok := c.negotiated[v.feature.Name]; ok {
				if _, ok := c.negotiated[v.feature.Name.Space]; ok {
					// If this feature has already been negotiated, skip it.
					continue
				}


@@ 163,7 163,7 @@ func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWri
		if err == nil {
			c.state |= mask
		}
		c.negotiated[data.feature.Name] = struct{}{}
		c.negotiated[data.feature.Name.Space] = struct{}{}

		// If we negotiated a required feature or a stream restart is required
		// we're done with this feature set.


@@ 184,7 184,9 @@ type sfData struct {
type streamFeaturesList struct {
	total int
	req   bool
	cache map[xml.Name]sfData

	// Namespace to sfData
	cache map[string]sfData
}

func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesList, err error) {


@@ 197,7 199,7 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesL

	// Lock the connection features list.
	list = &streamFeaturesList{
		cache: make(map[xml.Name]sfData),
		cache: make(map[string]sfData),
	}

	for _, feature := range conn.config.Features {


@@ 211,7 213,7 @@ func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesL
			if err != nil {
				return
			}
			list.cache[feature.Name] = sfData{
			list.cache[feature.Name.Space] = sfData{
				req:     r,
				data:    nil,
				feature: feature,


@@ 244,7 246,7 @@ func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement)
	defer conn.flock.Unlock()

	sf := &streamFeaturesList{
		cache: make(map[xml.Name]sfData),
		cache: make(map[string]sfData),
	}

parsefeatures:


@@ 258,13 260,13 @@ parsefeatures:
			// If the token is a new feature, see if it's one we handle. If so, parse
			// it. Increment the total features count regardless.
			sf.total++
			conn.features[tok.Name] = struct{}{}
			conn.features[tok.Name.Space] = struct{}{}
			if feature, ok := conn.config.Features[tok.Name]; ok && (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
				req, data, err := feature.Parse(ctx, conn.in.d, &tok)
				if err != nil {
					return nil, err
				}
				sf.cache[tok.Name] = sfData{
				sf.cache[tok.Name.Space] = sfData{
					req:     req,
					data:    data,
					feature: feature,

M stream.go => stream.go +2 -2
@@ 215,8 215,8 @@ func (c *Conn) negotiateStreams(ctx context.Context, rwc io.ReadWriteCloser) (er
	// is still required.
	for done := false; !done || rwc != nil; {
		if rwc != nil {
			c.features = make(map[xml.Name]struct{})
			c.negotiated = make(map[xml.Name]struct{})
			c.features = make(map[string]struct{})
			c.negotiated = make(map[string]struct{})
			c.rwc = rwc
			c.in.d = xml.NewDecoder(c.rwc)
			c.out.e = xml.NewEncoder(c.rwc)