~samwhited/xmpp

9bc84e972a3cf6b1af4a97af8b041cfeb4f3b0ac — Sam Whited 5 years ago 0504424
Don't special case the stream:stream wrapper

And don't recurse when negotiating inner streams, which could have
possibly lead to a malicious server being able to overflow the stack
2 files changed, 35 insertions(+), 31 deletions(-)

M features.go
M stream.go
M features.go => features.go +2 -2
@@ 145,7 145,7 @@ func readStreamFeatures(ctx context.Context, conn *Conn, start xml.StartElement)
	switch {
	case start.Name.Local != "features":
		return nil, InvalidXML
	case start.Name.Space != "stream":
	case start.Name.Space != NSStream:
		return nil, BadNamespacePrefix
	}



@@ 184,7 184,7 @@ parsefeatures:
				return nil, err
			}
		case xml.EndElement:
			if tok.Name.Local == "features" && tok.Name.Space == "stream" {
			if tok.Name.Local == "features" && tok.Name.Space == NSStream {
				// We've reached the end of the features list!
				return sf, nil
			}

M stream.go => stream.go +33 -29
@@ 164,7 164,7 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
			return ctx.Err()
		default:
		}
		t, err := d.RawToken()
		t, err := d.Token()
		if err != nil {
			return err
		}


@@ 173,7 173,7 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
			switch {
			case tok.Name.Local != "stream":
				return BadFormat
			case tok.Name.Space != "stream":
			case tok.Name.Space != NSStream:
				return InvalidNamespace
			}



@@ 190,9 190,11 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
					// if we are the initiating entity and there is no stream ID…
					return BadFormat
				}
				conn.state &= ^StreamRestartRequired
				conn.in.stream = stream
				conn.in.d = xml.NewDecoder(r)
				if (conn.state & StreamRestartRequired) == StreamRestartRequired {
					conn.state &= ^StreamRestartRequired
					conn.in.stream = stream
					conn.in.d = xml.NewDecoder(r)
				}
			}
			return nil
		case xml.ProcInst:


@@ 211,33 213,35 @@ func expectNewStream(ctx context.Context, r io.Reader) error {
}

func (c *Conn) negotiateStreams(ctx context.Context) (err error) {
	if (c.state & Received) == Received {
		if err = expectNewStream(ctx, c); err != nil {
			return err
		}
		if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil {
			return err
		}
	} else {
		if err := sendNewStream(c, c.config, ""); err != nil {
			return err
		}
		if err := expectNewStream(ctx, c); err != nil {
			return err
restartstream:
	for {
		if (c.state & Received) == Received {
			if err = expectNewStream(ctx, c); err != nil {
				return err
			}
			if err = sendNewStream(c, c.config, internal.RandomID(streamIDLength)); err != nil {
				return err
			}
		} else {
			if err := sendNewStream(c, c.config, ""); err != nil {
				return err
			}
			if err := expectNewStream(ctx, c); err != nil {
				return err
			}
		}
	}

	for done := false; !done; done, err = c.negotiateFeatures(ctx) {
		switch {
		case err != nil:
			return err
		case c.state&StreamRestartRequired == StreamRestartRequired:
			// If we require a stream restart, do so…

			// BUG(ssw): Negotiating streams can lead to a stack overflow when
			//           connecting to a malicious endpoint.
			return c.negotiateStreams(ctx)
		for done := false; !done; {
			done, err = c.negotiateFeatures(ctx)
			switch {
			case err != nil:
				return err
			case c.state&StreamRestartRequired == StreamRestartRequired:
				// If we require a stream restart, do so…
				continue restartstream
			}
		}
		break
	}
	panic("xmpp: Not yet implemented.")
}