~samwhited/xmpp

3ee9e06fbe4cc1e7e8e3cf8cd8b6fcba3b333ab0 — Sam Whited 3 years ago 0d549e5
xmpp: remove TLSConfig from Config

See #38
3 files changed, 13 insertions(+), 22 deletions(-)

M config.go
M starttls.go
M starttls_test.go
M config.go => config.go +0 -5
@@ 5,8 5,6 @@
package xmpp

import (
	"crypto/tls"

	"golang.org/x/text/language"
	"mellium.im/xmpp/jid"
)


@@ 26,9 24,6 @@ type Config struct {
	// The default language for any streams constructed using this config.
	Lang language.Tag

	// TLS config for STARTTLS.
	TLSConfig *tls.Config

	// The authorization identity, and password to authenticate with.
	// Identity is used when a user wants to act on behalf of another user. For
	// instance, an admin might want to log in as another user to help them

M starttls.go => starttls.go +6 -10
@@ 28,7 28,7 @@ var (
// StartTLS returns a new stream feature that can be used for negotiating TLS.
// For StartTLS to work, the underlying connection must support TLS (it must
// implement net.Conn).
func StartTLS(required bool) StreamFeature {
func StartTLS(required bool, cfg *tls.Config) StreamFeature {
	return StreamFeature{
		Name:       xml.Name{Local: "starttls", Space: ns.StartTLS},
		Prohibited: Secure,


@@ 63,23 63,19 @@ func StartTLS(required bool) StreamFeature {
				return mask, nil, ErrTLSUpgradeFailed
			}

			config := session.Config()
			state := session.State()
			d := xml.NewTokenDecoder(session)

			// Fetch or create a TLSConfig to use.
			var tlsconf *tls.Config
			if config.TLSConfig == nil {
				tlsconf = &tls.Config{
			// If no TLSConfig was specified, use a default config.
			if cfg == nil {
				cfg = &tls.Config{
					ServerName: session.LocalAddr().Domain().String(),
				}
			} else {
				tlsconf = config.TLSConfig
			}

			if (state & Received) == Received {
				fmt.Fprint(conn, `<proceed xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)
				rw = tls.Server(conn, tlsconf)
				rw = tls.Server(conn, cfg)
			} else {
				// Select starttls for negotiation.
				fmt.Fprint(conn, `<starttls xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`)


@@ 99,7 95,7 @@ func StartTLS(required bool) StreamFeature {
						if err = d.Skip(); err != nil {
							return mask, nil, stream.InvalidXML
						}
						rw = tls.Client(conn, tlsconf)
						rw = tls.Client(conn, cfg)
					case tok.Name.Local == "failure":
						// Skip the </failure> token.
						if err = d.Skip(); err != nil {

M starttls_test.go => starttls_test.go +7 -7
@@ 22,7 22,7 @@ import (
// through the list process token for token.
func TestStartTLSList(t *testing.T) {
	for _, req := range []bool{true, false} {
		stls := StartTLS(req)
		stls := StartTLS(req, nil)
		var b bytes.Buffer
		e := xml.NewEncoder(&b)
		start := xml.StartElement{Name: xml.Name{Space: ns.StartTLS, Local: "starttls"}}


@@ 81,7 81,7 @@ func TestStartTLSList(t *testing.T) {
}

func TestStartTLSParse(t *testing.T) {
	stls := StartTLS(true)
	stls := StartTLS(true, nil)
	for _, test := range []struct {
		msg string
		req bool


@@ 147,7 147,7 @@ func (dummyConn) SetWriteDeadline(t time.Time) error {
// We can't create a tls.Client or tls.Server for a generic ReadWriter, so
// ensure that we fail (with a specific error) if this is the case.
func TestNegotiationFailsForNonNetSession(t *testing.T) {
	stls := StartTLS(true)
	stls := StartTLS(true, nil)
	var b bytes.Buffer
	_, _, err := stls.Negotiate(context.Background(), &Session{rw: nopRWC{&b, &b}}, nil)
	if err != ErrTLSUpgradeFailed {


@@ 156,9 156,9 @@ func TestNegotiationFailsForNonNetSession(t *testing.T) {
}

func TestNegotiateServer(t *testing.T) {
	stls := StartTLS(true)
	stls := StartTLS(true, &tls.Config{})
	var b bytes.Buffer
	c := &Session{state: Received, conn: dummyConn{nopRWC{&b, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
	c := &Session{state: Received, conn: dummyConn{nopRWC{&b, &b}}}
	c.rw = c.conn
	_, rw, err := stls.Negotiate(context.Background(), c, nil)
	switch {


@@ 194,10 194,10 @@ func TestNegotiateClient(t *testing.T) {
		{[]string{`<notproceedorfailure xmlns='urn:ietf:params:xml:ns:xmpp-tls'/>`}, true, false, 0},
		{[]string{`chardata not start element`}, true, false, 0},
	} {
		stls := StartTLS(true)
		stls := StartTLS(true, &tls.Config{})
		r := strings.NewReader(strings.Join(test.responses, "\n"))
		var b bytes.Buffer
		c := &Session{conn: dummyConn{nopRWC{r, &b}}, config: &Config{TLSConfig: &tls.Config{}}}
		c := &Session{conn: dummyConn{nopRWC{r, &b}}}
		c.rw = c.conn
		c.in.d = xml.NewDecoder(c.rw)
		mask, rw, err := stls.Negotiate(context.Background(), c, nil)