~stepbrobd/tailscale

43138c7a5c8815ea104499866440e34bb1220e93 — VimT 2 months ago b0626ff
net/socks5: optimize UDP relay

Key changes:
- No mutex for every udp package: replace syncs.Map with regular map for udpTargetConns
- Use socksAddr as map key for better type safety
- Add test for multi udp target

Updates #7581

Change-Id: Ic3d384a9eab62dcbf267d7d6d268bf242cc8ed3c
Signed-off-by: VimT <me@vimt.me>
2 files changed, 119 insertions(+), 99 deletions(-)

M net/socks5/socks5.go
M net/socks5/socks5_test.go
M net/socks5/socks5.go => net/socks5/socks5.go +25 -27
@@ 22,7 22,6 @@ import (
	"log"
	"net"
	"strconv"
	"tailscale.com/syncs"
	"time"

	"tailscale.com/types/logger"


@@ 151,7 150,7 @@ type Conn struct {
	request    *request

	udpClientAddr  net.Addr
	udpTargetConns syncs.Map[string, net.Conn]
	udpTargetConns map[socksAddr]net.Conn
}

// Run starts the new connection.


@@ 311,17 310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// close all target udp connections when the client connection is closed
	defer func() {
		c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
			_ = conn.Close()
			return true
		})
	}()

	// client -> target
	go func() {
		defer cancel()

		c.udpTargetConns = make(map[socksAddr]net.Conn)
		// close all target udp connections when the client connection is closed
		defer func() {
			for _, conn := range c.udpTargetConns {
				_ = conn.Close()
			}
		}()

		buf := make([]byte, bufferSize)
		for {
			select {


@@ 354,33 354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
func (c *Conn) getOrDialTargetConn(
	ctx context.Context,
	clientConn net.PacketConn,
	targetAddr string,
	targetAddr socksAddr,
) (net.Conn, error) {
	host, port, err := splitHostPort(targetAddr)
	if err != nil {
		return nil, err
	}

	conn, loaded := c.udpTargetConns.Load(targetAddr)
	if loaded {
	conn, exist := c.udpTargetConns[targetAddr]
	if exist {
		return conn, nil
	}
	conn, err = c.srv.dial(ctx, "udp", targetAddr)
	conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
	if err != nil {
		return nil, err
	}
	c.udpTargetConns.Store(targetAddr, conn)
	c.udpTargetConns[targetAddr] = conn

	// target -> client
	go func() {
		buf := make([]byte, bufferSize)
		addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
		for {
			select {
			case <-ctx.Done():
				return
			default:
				err := c.handleUDPResponse(clientConn, addr, conn, buf)
				err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
				if err != nil {
					if isTimeout(err) {
						continue


@@ 414,18 408,17 @@ func (c *Conn) handleUDPRequest(
		return fmt.Errorf("parse udp request: %w", err)
	}

	targetAddr := req.addr.hostPort()
	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
	if err != nil {
		return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
		return fmt.Errorf("dial target %s fail: %w", req.addr, err)
	}

	nn, err := targetConn.Write(data)
	if err != nil {
		return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
		return fmt.Errorf("write to target %s fail: %w", req.addr, err)
	}
	if nn != len(data) {
		return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
		return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
	}
	return nil
}


@@ 652,10 645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
	pkt = binary.BigEndian.AppendUint16(pkt, s.port)
	return pkt, nil
}

func (s socksAddr) hostPort() string {
	return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}

func (s socksAddr) String() string {
	return s.hostPort()
}

// response contains the contents of
// a response packet sent from the proxy
// to the client.

M net/socks5/socks5_test.go => net/socks5/socks5_test.go +94 -72
@@ 169,12 169,25 @@ func TestReadPassword(t *testing.T) {

func TestUDP(t *testing.T) {
	// backend UDP server which we'll use SOCKS5 to connect to
	listener, err := net.ListenPacket("udp", ":0")
	if err != nil {
		t.Fatal(err)
	newUDPEchoServer := func() net.PacketConn {
		listener, err := net.ListenPacket("udp", ":0")
		if err != nil {
			t.Fatal(err)
		}
		go udpEchoServer(listener)
		return listener
	}
	backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
	go udpEchoServer(listener)

	const echoServerNumber = 3
	echoServerListener := make([]net.PacketConn, echoServerNumber)
	for i := 0; i < echoServerNumber; i++ {
		echoServerListener[i] = newUDPEchoServer()
	}
	defer func() {
		for i := 0; i < echoServerNumber; i++ {
			_ = echoServerListener[i].Close()
		}
	}()

	// SOCKS5 server
	socks5, err := net.Listen("tcp", ":0")


@@ 184,84 197,93 @@ func TestUDP(t *testing.T) {
	socks5Port := socks5.Addr().(*net.TCPAddr).Port
	go socks5Server(socks5)

	// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
	conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
	if err != nil {
		t.Fatal(err)
	}
	_, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
	if err != nil {
		t.Fatal(err)
	}
	buf := make([]byte, 1024)
	n, err := conn.Read(buf) // server hello
	if err != nil {
		t.Fatal(err)
	}
	if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
		t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
	}
	// make a socks5 udpAssociate conn
	newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
		// net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
		conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
		if err != nil {
			t.Fatal(err)
		}
		_, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
		if err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, 1024)
		n, err := conn.Read(buf) // server hello
		if err != nil {
			t.Fatal(err)
		}
		if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
			t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
		}

	targetAddr := socksAddr{
		addrType: domainName,
		addr:     "localhost",
		port:     uint16(backendServerPort),
	}
	targetAddrPkt, err := targetAddr.marshal()
	if err != nil {
		t.Fatal(err)
	}
	_, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
	if err != nil {
		t.Fatal(err)
	}
		targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
		targetAddrPkt, err := targetAddr.marshal()
		if err != nil {
			t.Fatal(err)
		}
		_, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
		if err != nil {
			t.Fatal(err)
		}

	n, err = conn.Read(buf) // server response
	if err != nil {
		t.Fatal(err)
	}
	if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
		t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
		n, err = conn.Read(buf) // server response
		if err != nil {
			t.Fatal(err)
		}
		if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
			t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
		}
		udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
		if err != nil {
			t.Fatal(err)
		}

		return conn, udpProxySocksAddr
	}
	udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
	if err != nil {
		t.Fatal(err)

	conn, udpProxySocksAddr := newUdpAssociateConn()
	defer conn.Close()

	sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
		udpPayload, err := (&udpRequest{addr: addr}).marshal()
		if err != nil {
			t.Fatal(err)
		}
		udpPayload = append(udpPayload, body...)
		_, err = socks5UDPConn.Write(udpPayload)
		if err != nil {
			t.Fatal(err)
		}
		buf := make([]byte, 1024)
		n, err := socks5UDPConn.Read(buf)
		if err != nil {
			t.Fatal(err)
		}
		_, responseBody, err = parseUDPRequest(buf[:n])
		if err != nil {
			t.Fatal(err)
		}
		return responseBody
	}

	udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
	if err != nil {
		t.Fatal(err)
	}
	udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
	if err != nil {
		t.Fatal(err)
	}
	udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
	if err != nil {
		t.Fatal(err)
	}
	udpPayload = append(udpPayload, []byte("Test")...)
	_, err = udpConn.Write(udpPayload) // send udp package
	if err != nil {
		t.Fatal(err)
	}
	n, _, err = udpConn.ReadFrom(buf)
	if err != nil {
		t.Fatal(err)
	}
	_, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
	if err != nil {
		t.Fatal(err)
	}
	if string(responseBody) != "Test" {
		t.Fatalf("got: %q want: Test", responseBody)
	}
	err = udpConn.Close()
	socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
	if err != nil {
		t.Fatal(err)
	}
	err = conn.Close()
	if err != nil {
		t.Fatal(err)
	defer socks5UDPConn.Close()

	for i := 0; i < echoServerNumber; i++ {
		port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
		addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
		requestBody := []byte(fmt.Sprintf("Test %d", i))
		responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
		if !bytes.Equal(requestBody, responseBody) {
			t.Fatalf("got: %q want: %q", responseBody, requestBody)
		}
	}
}