~samwhited/xmpp

30563f99e5f371559dfe948fc440268574e0b67c — Sam Whited 4 years ago 37d8a6f
Save the features list on Conn
3 files changed, 38 insertions(+), 35 deletions(-)

M conn.go
M features.go
M stream.go
M conn.go => conn.go +4 -0
@@ 27,6 27,9 @@ type Conn struct {
	// server did not assign us the resourcepart we requested, this is canonical).
	origin *jid.JID

	// The stream features advertised for the current streams.
	features map[xml.Name]struct{}

	in struct {
		stream
		d *xml.Decoder


@@ 45,6 48,7 @@ func NewConn(ctx context.Context, config *Config, rwc io.ReadWriteCloser) (*Conn
	c := &Conn{
		config: config,
		rwc:    rwc,
		state:  StreamRestartRequired,
	}

	return c, c.negotiateStreams(ctx)

M features.go => features.go +1 -0
@@ 167,6 167,7 @@ parsefeatures:
			// If the token is a new feature, see if it's one we handle. If so, parse
			// it. Increment the total features count regardless.
			sf.total += 1
			conn.features[tok.Name] = struct{}{}
			if feature, ok := conn.config.Features[tok.Name]; ok && (conn.state&feature.Necessary) == feature.Necessary && (conn.state&feature.Prohibited) == 0 {
				req, data, err := feature.Parse(ctx, conn.in.d, &tok)
				if err != nil {

M stream.go => stream.go +33 -35
@@ 142,20 142,17 @@ func sendNewStream(w io.Writer, cfg *Config, id string) error {
	}

	if conn, ok := w.(*Conn); ok {
		conn.state &= ^StreamRestartRequired
		conn.out.stream = stream
		conn.out.e = xml.NewEncoder(w)
	}
	return nil
}

func expectNewStream(ctx context.Context, r io.Reader) error {
	var foundHeader bool

	// If the reader is a Conn, use its decoder, otherwise make a new one.
	var d *xml.Decoder
	if conn, ok := r.(*Conn); ok {
		if conn.in.d == nil {
			conn.in.d = xml.NewDecoder(r)
		}
		d = conn.in.d
	} else {
		d = xml.NewDecoder(r)


@@ 198,11 195,7 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
					// if we are the initiating entity and there is no stream ID…
					return streamerror.BadFormat
				}
				if (conn.state & StreamRestartRequired) == StreamRestartRequired {
					conn.state &= ^StreamRestartRequired
					conn.in.stream = stream
					conn.in.d = xml.NewDecoder(r)
				}
				conn.in.stream = stream
			}
			return nil
		case xml.ProcInst:


@@ 221,35 214,40 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
}

func (c *Conn) negotiateStreams(ctx context.Context) (err error) {
restartstream:
	for {
		if (c.state & Received) == Received {
			if err = expectNewStream(ctx, c); err != nil {
				return err
			}
			if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil {
				return err
			}
		} else {
			if err := sendNewStream(c, c.config, ""); err != nil {
				return err
			}
			if err := expectNewStream(ctx, c); err != nil {
				return err

	// Loop for as long as we're not done negotiating features or a stream restart
	// is still required.
	for done := false; !done || c.state&StreamRestartRequired == StreamRestartRequired; {
		if c.state&StreamRestartRequired == StreamRestartRequired {
			c.features = make(map[xml.Name]struct{})
			c.in.d = xml.NewDecoder(c.rwc)
			c.out.e = xml.NewEncoder(c.rwc)
			c.state &= ^StreamRestartRequired

			if (c.state & Received) == Received {
				// If we're the receiving entity wait for a new stream, then send one in
				// response.
				if err = expectNewStream(ctx, c); err != nil {
					return err
				}
				if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil {
					return err
				}
			} else {
				// If we're the initiating entity, send a new stream and then wait for one
				// in response.
				if err := sendNewStream(c, c.config, ""); err != nil {
					return err
				}
				if err := expectNewStream(ctx, c); err != nil {
					return err
				}
			}
		}

		for done := false; !done; {
			done, err = c.negotiateFeatures(ctx)
			switch {
			case err != nil:
				return err
			case c.state&StreamRestartRequired == StreamRestartRequired:
				// If we require a stream restart, do so…
				continue restartstream
			}
		if done, err = c.negotiateFeatures(ctx); err != nil {
			return err
		}
		break
	}
	return nil
}