~samwhited/xmpp

e8241b83e29bc48f068c167b9910dc2108c7d464 — Sam Whited 5 years ago fac775d
Add NoLookup (no SRV or TXT lookup) to Dialer

Also use DialContext on Go 1.7 and avoid annoying hacks
5 files changed, 188 insertions(+), 79 deletions(-)

M config.go
M dial.go
A dial1_6.go
A dial1_7.go
M lookup.go
M config.go => config.go +7 -0
@@ 61,3 61,10 @@ func NewServerConfig(location, origin *jid.JID) *Config {
		},
	}
}

func (config *Config) connType() string {
	if config.S2S {
		return "xmpp-server"
	}
	return "xmpp-client"
}

M dial.go => dial.go +5 -67
@@ 6,7 6,6 @@ package xmpp

import (
	"net"
	"strconv"
	"time"

	"golang.org/x/net/context"


@@ 56,6 55,11 @@ func Dial(ctx context.Context, network string, config *Config) (*Conn, error) {
// the DialClient function.
type Dialer struct {
	net.Dialer

	// NoLookup stops the dialer from looking up SRV or TXT records for the given
	// domain. It also prevents fetching of the host metadata file.
	// Instead, it will try to connect to the domain directly.
	NoLookup bool
}

// Copied from the net package in the standard library. Copyright The Go


@@ 88,13 92,6 @@ func (d *Dialer) deadline(ctx context.Context, now time.Time) (earliest time.Tim
	return minNonzeroTime(earliest, d.Deadline)
}

func connType(config *Config) string {
	if config.S2S {
		return "xmpp-server"
	}
	return "xmpp-client"
}

// DialClient discovers and connects to the address on the named network that
// services the given local address with a client-to-server (c2s) connection.
//


@@ 126,62 123,3 @@ func (d *Dialer) Dial(

	return c, err
}

func (d *Dialer) dial(
	ctx context.Context, network string, config *Config) (*Conn, error) {
	if ctx == nil {
		panic("xmpp.Dial: nil context")
	}

	deadline := d.deadline(ctx, time.Now())
	if !deadline.IsZero() {
		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
			subCtx, cancel := context.WithDeadline(ctx, deadline)
			defer cancel()
			ctx = subCtx
		}
	}
	if oldCancel := d.Cancel; oldCancel != nil {
		subCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		go func() {
			select {
			case <-oldCancel:
				cancel()
			case <-subCtx.Done():
			}
		}()
		ctx = subCtx
	}

	c := &Conn{
		config: config,
	}

	addrs, err := lookupService(connType(config), c.RemoteAddr())
	if err != nil {
		return nil, err
	}

	// Try dialing all of the SRV records we know about, breaking as soon as the
	// connection is established.
	for _, addr := range addrs {
		if conn, e := d.Dialer.Dial(
			network, net.JoinHostPort(
				addr.Target, strconv.FormatUint(uint64(addr.Port), 10),
			),
		); e != nil {
			err = e
			continue
		} else {
			err = nil
			c.rwc = conn
			break
		}
	}
	if err != nil {
		return nil, err
	}

	return c, nil
}

A dial1_6.go => dial1_6.go +90 -0
@@ 0,0 1,90 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

// +build !go1.7

package xmpp

import (
	"net"
	"strconv"

	"golang.org/x/net/context"
)

func (d *Dialer) dial(
	ctx context.Context, network string, config *Config) (*Conn, error) {
	if ctx == nil {
		panic("xmpp.Dial: nil context")
	}

	// Backwards compatibility with old net.Dialer cancelation methods.
	deadline := d.deadline(ctx, time.Now())
	if !deadline.IsZero() {
		if d, ok := ctx.Deadline(); !ok || deadline.Before(d) {
			subCtx, cancel := context.WithDeadline(ctx, deadline)
			defer cancel()
			ctx = subCtx
		}
	}
	if oldCancel := d.Cancel; oldCancel != nil {
		subCtx, cancel := context.WithCancel(ctx)
		defer cancel()
		go func() {
			select {
			case <-oldCancel:
				cancel()
			case <-subCtx.Done():
			}
		}()
		ctx = subCtx
	}

	c := &Conn{
		config: config,
	}

	if d.NoLookup {
		p, err := lookupPort(network, config.connType())
		if err != nil {
			return nil, err
		}
		conn, err := d.Dialer.Dial(network, net.JoinHostPort(
			config.Location.Domainpart(),
			strconv.FormatUint(uint64(p), 10),
		))
		if err != nil {
			return nil, err
		}
		c.rwc = conn
		return c, nil
	}

	addrs, err := lookupService(config.connType(), network, c.RemoteAddr())
	if err != nil {
		return nil, err
	}

	// Try dialing all of the SRV records we know about, breaking as soon as the
	// connection is established.
	for _, addr := range addrs {
		if conn, e := d.Dialer.Dial(
			network, net.JoinHostPort(
				addr.Target, strconv.FormatUint(uint64(addr.Port), 10),
			),
		); e != nil {
			err = e
			continue
		} else {
			err = nil
			c.rwc = conn
			break
		}
	}
	if err != nil {
		return nil, err
	}

	return c, nil
}

A dial1_7.go => dial1_7.go +67 -0
@@ 0,0 1,67 @@
// Copyright 2016 Sam Whited.
// Use of this source code is governed by the BSD 2-clause license that can be
// found in the LICENSE file.

// +build go1.7

package xmpp

import (
	"context"
	"net"
	"strconv"
)

func (d *Dialer) dial(
	ctx context.Context, network string, config *Config) (*Conn, error) {
	if ctx == nil {
		panic("xmpp.Dial: nil context")
	}

	c := &Conn{
		config: config,
	}

	if d.NoLookup {
		p, err := lookupPort(network, config.connType())
		if err != nil {
			return nil, err
		}
		conn, err := d.Dialer.DialContext(ctx, network, net.JoinHostPort(
			config.Location.Domainpart(),
			strconv.FormatUint(uint64(p), 10),
		))
		if err != nil {
			return nil, err
		}
		c.rwc = conn
		return c, nil
	}

	addrs, err := lookupService(config.connType(), network, c.RemoteAddr())
	if err != nil {
		return nil, err
	}

	// Try dialing all of the SRV records we know about, breaking as soon as the
	// connection is established.
	for _, addr := range addrs {
		if conn, e := d.Dialer.DialContext(
			ctx, network, net.JoinHostPort(
				addr.Target, strconv.FormatUint(uint64(addr.Port), 10),
			),
		); e != nil {
			err = e
			continue
		} else {
			err = nil
			c.rwc = conn
			break
		}
	}
	if err != nil {
		return nil, err
	}

	return c, nil
}

M lookup.go => lookup.go +19 -12
@@ 39,11 39,27 @@ var (
	}
)

func lookupPort(network, service string) (int, error) {
	p, err := net.LookupPort(network, service)
	if err == nil {
		return p, err
	}
	switch service {
	case "xmpp-client":
		return 5222, nil
	case "xmpp-server":
		return 5269, nil
	case "xmpp-bosh":
		return 5280, nil
	}
	return 0, err
}

// lookupService looks for an XMPP service hosted by the given address. It
// returns addresses from SRV records or the default domain (as a fake SRV
// record) if no real records exist. Service should be one of "xmpp-client" or
// "xmpp-server".
func lookupService(service string, addr net.Addr) (addrs []*net.SRV, err error) {
func lookupService(service, network string, addr net.Addr) (addrs []*net.SRV, err error) {
	switch j := addr.(type) {
	case *jid.JID:
		addr = j.Domain()


@@ 64,18 80,9 @@ func lookupService(service string, addr net.Addr) (addrs []*net.SRV, err error) 
	}

	// Use domain and default port.
	p, err := net.LookupPort("tcp", service)
	p, err := lookupPort(network, service)
	if err != nil {
		switch service {
		case "xmpp-client":
			p = 5222
		case "xmpp-server":
			p = 5269
		case "xmpp-bosh":
			p = 5280
		default:
			return nil, err
		}
		return nil, err
	}
	addrs = []*net.SRV{{
		Target: addr.String(),