~samwhited/xmpp

571fcbc05252cc760f5fad23ffe78911703ead0b — Sam Whited 5 years ago 7424ea1
Refactor writing stream features
1 files changed, 89 insertions(+), 83 deletions(-)

M features.go
M features.go => features.go +89 -83
@@ 65,105 65,76 @@ type StreamFeature struct {
	Negotiate func(ctx context.Context, conn *Conn, data interface{}) (mask SessionState, rwc io.ReadWriteCloser, err error)
}

// Returns the number of stream features written (zero means we've reached the
// end of negotiation), and the number of required features written (zero means
// we've potentially reached the end of negotiation, but the client may
// negotiate more optional features).
func writeStreamFeatures(ctx context.Context, conn *Conn) (n int, req int, err error) {
	if _, err = fmt.Fprint(conn, `<stream:features>`); err != nil {
		return
	}
	for _, feature := range conn.config.Features {
		// Check if all the necessary bits are set and none of the prohibited bits
		// are set.
		if (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
			var r bool
			r, err = feature.List(ctx, conn.out.e, xml.StartElement{
				Name: feature.Name,
			})
			if err != nil {
				return
			}
			if r {
				req++
			}
			n++
		}
	}
	_, err = fmt.Fprint(conn, `</stream:features>`)
	return
}

func (c *Conn) negotiateFeatures(ctx context.Context) (done bool, rwc io.ReadWriteCloser, err error) {
	if (c.state & Received) == Received {
		_, _, err = writeStreamFeatures(ctx, c)
		_, err = writeStreamFeatures(ctx, c)
		if err != nil {
			return false, nil, err
		}
		panic("Sending stream:features not yet implemented")
	} else {
		t, err := c.in.d.Token()
		if err != nil {
			return done, nil, err
		}
		start, ok := t.(xml.StartElement)
		if !ok {
			return done, nil, streamerror.BadFormat
		}
		list, err := readStreamFeatures(ctx, c, start)

		switch {
		case err != nil:
			return done, nil, err
		case list.total == 0 || len(list.cache) == 0:
			// If we received an empty list (or one with no supported features), we're
			// done.
			return true, nil, nil
		}
	}

		// If the list has any optional items that we support, negotiate them first
		// before moving on to the required items.
		for {
			var data sfData
			for _, v := range list.cache {
				if _, ok := c.negotiated[v.feature.Name]; ok {
					// If this feature has already been negotiated, skip it (servers
					// shouldn't list them in this case, but you never know).
					continue
				}
	t, err := c.in.d.Token()
	if err != nil {
		return done, nil, err
	}
	start, ok := t.(xml.StartElement)
	if !ok {
		return done, nil, streamerror.BadFormat
	}
	list, err := readStreamFeatures(ctx, c, start)

				// If the feature is optional, select it.
				if !v.req {
					data = v
					break
				}
	switch {
	case err != nil:
		return done, nil, err
	case list.total == 0 || len(list.cache) == 0:
		// If we received an empty list (or one with no supported features), we're
		// done.
		return true, nil, nil
	}

				// If the feature is required, tentatively select it (but finish looking
				// for optional features).
				if v.req {
					data = v
				}
			}
			// No features that haven't already been negotiated were sent… we're done.
			if data.feature.Name.Local == "" {
				return true, nil, nil
			}
			var mask SessionState
			mask, rwc, err = data.feature.Negotiate(ctx, c, data.data)
			if err == nil {
				c.state |= mask
	// If the list has any optional items that we support, negotiate them first
	// before moving on to the required items.
	for {
		var data sfData
		for _, v := range list.cache {
			if _, ok := c.negotiated[v.feature.Name]; ok {
				// If this feature has already been negotiated, skip it (servers
				// shouldn't list them in this case, but you never know).
				continue
			}
			c.negotiated[data.feature.Name] = struct{}{}

			// If we negotiated a required feature or a stream restart is required
			// we're done with this feature set.
			if rwc != nil || data.req {
			// If the feature is optional, select it.
			if !v.req {
				data = v
				break
			}

			// If the feature is required, tentatively select it (but finish looking
			// for optional features).
			if v.req {
				data = v
			}
		}
		// No features that haven't already been negotiated were sent… we're done.
		if data.feature.Name.Local == "" {
			return true, nil, nil
		}
		var mask SessionState
		mask, rwc, err = data.feature.Negotiate(ctx, c, data.data)
		if err == nil {
			c.state |= mask
		}
		c.negotiated[data.feature.Name] = struct{}{}

		return !list.req || (c.state&Ready == Ready), rwc, err
		// If we negotiated a required feature or a stream restart is required
		// we're done with this feature set.
		if rwc != nil || data.req {
			break
		}
	}

	return !list.req || (c.state&Ready == Ready), rwc, err
}

type sfData struct {


@@ 178,6 149,41 @@ type streamFeaturesList struct {
	cache map[xml.Name]sfData
}

func writeStreamFeatures(ctx context.Context, conn *Conn) (list *streamFeaturesList, err error) {
	if _, err = fmt.Fprint(conn, `<stream:features>`); err != nil {
		return
	}
	// Lock the connection features list.
	list = &streamFeaturesList{
		cache: make(map[xml.Name]sfData),
	}

	for _, feature := range conn.config.Features {
		// Check if all the necessary bits are set and none of the prohibited bits
		// are set.
		if (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
			var r bool
			r, err = feature.List(ctx, conn.out.e, xml.StartElement{
				Name: feature.Name,
			})
			if err != nil {
				return
			}
			list.cache[feature.Name] = sfData{
				req:     r,
				data:    nil,
				feature: feature,
			}
			if r {
				list.req = true
			}
			list.total++
		}
	}
	_, err = fmt.Fprint(conn, `</stream:features>`)
	return
}

func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement) (*streamFeaturesList, error) {
	switch {
	case start.Name.Local != "features":