~stepbrobd/tailscale

b0626ff84c11f8ad5c680fdec214eb5981307f1c — VimT 2 months ago 634cc2b
net/socks5: fix UDP relay in userspace-networking mode

This commit addresses an issue with the SOCKS5 UDP relay functionality
when using the --tun=userspace-networking option. Previously, UDP packets
were not being correctly routed into the Tailscale network in this mode.

Key changes:
- Replace single UDP connection with a map of connections per target
- Use c.srv.dial for creating connections to ensure proper routing

Updates #7581

Change-Id: Iaaa66f9de6a3713218014cf3f498003a7cac9832
Signed-off-by: VimT <me@vimt.me>
1 files changed, 63 insertions(+), 38 deletions(-)

M net/socks5/socks5.go
M net/socks5/socks5.go => net/socks5/socks5.go +63 -38
@@ 22,6 22,7 @@ import (
	"log"
	"net"
	"strconv"
	"tailscale.com/syncs"
	"time"

	"tailscale.com/types/logger"


@@ 81,6 82,12 @@ const (
	addrTypeNotSupported replyCode = 8
)

// UDP conn default buffer size and read timeout.
const (
	bufferSize  = 8 * 1024
	readTimeout = 5 * time.Second
)

// Server is a SOCKS5 proxy server.
type Server struct {
	// Logf optionally specifies the logger to use.


@@ 143,7 150,8 @@ type Conn struct {
	clientConn net.Conn
	request    *request

	udpClientAddr net.Addr
	udpClientAddr  net.Addr
	udpTargetConns syncs.Map[string, net.Conn]
}

// Run starts the new connection.


@@ 276,15 284,6 @@ func (c *Conn) handleUDP() error {
	}
	defer clientUDPConn.Close()

	serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
	if err != nil {
		res := errorResponse(generalFailure)
		buf, _ := res.marshal()
		c.clientConn.Write(buf)
		return err
	}
	defer serverUDPConn.Close()

	bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
	if err != nil {
		return err


@@ 305,14 304,20 @@ func (c *Conn) handleUDP() error {
	}
	c.clientConn.Write(buf)

	return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
	return c.transferUDP(c.clientConn, clientUDPConn)
}

func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	const bufferSize = 8 * 1024
	const readTimeout = 5 * time.Second

	// close all target udp connections when the client connection is closed
	defer func() {
		c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
			_ = conn.Close()
			return true
		})
	}()

	// client -> target
	go func() {


@@ 323,7 328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
			case <-ctx.Done():
				return
			default:
				err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
				err := c.handleUDPRequest(ctx, clientConn, buf)
				if err != nil {
					if isTimeout(err) {
						continue


@@ 337,21 342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
		}
	}()

	// A UDP association terminates when the TCP connection that the UDP
	// ASSOCIATE request arrived on terminates. RFC1928
	_, err := io.Copy(io.Discard, associatedTCP)
	if err != nil {
		err = fmt.Errorf("udp associated tcp conn: %w", err)
	}
	return err
}

func (c *Conn) getOrDialTargetConn(
	ctx context.Context,
	clientConn net.PacketConn,
	targetAddr string,
) (net.Conn, error) {
	host, port, err := splitHostPort(targetAddr)
	if err != nil {
		return nil, err
	}

	conn, loaded := c.udpTargetConns.Load(targetAddr)
	if loaded {
		return conn, nil
	}
	conn, err = c.srv.dial(ctx, "udp", targetAddr)
	if err != nil {
		return nil, err
	}
	c.udpTargetConns.Store(targetAddr, conn)

	// target -> client
	go func() {
		defer cancel()
		buf := make([]byte, bufferSize)
		addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
		for {
			select {
			case <-ctx.Done():
				return
			default:
				err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
				err := c.handleUDPResponse(clientConn, addr, conn, buf)
				if err != nil {
					if isTimeout(err) {
						continue
					}
					if errors.Is(err, net.ErrClosed) {
					if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
						return
					}
					c.logf("udp transfer: handle udp response fail: %v", err)


@@ 360,20 394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
		}
	}()

	// A UDP association terminates when the TCP connection that the UDP
	// ASSOCIATE request arrived on terminates. RFC1928
	_, err := io.Copy(io.Discard, associatedTCP)
	if err != nil {
		err = fmt.Errorf("udp associated tcp conn: %w", err)
	}
	return err
	return conn, nil
}

func (c *Conn) handleUDPRequest(
	ctx context.Context,
	clientConn net.PacketConn,
	targetConn net.PacketConn,
	buf []byte,
	readTimeout time.Duration,
) error {
	// add a deadline for the read to avoid blocking forever
	_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))


@@ 386,12 413,14 @@ func (c *Conn) handleUDPRequest(
	if err != nil {
		return fmt.Errorf("parse udp request: %w", err)
	}
	targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())

	targetAddr := req.addr.hostPort()
	targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
	if err != nil {
		c.logf("resolve target addr fail: %v", err)
		return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
	}

	nn, err := targetConn.WriteTo(data, targetAddr)
	nn, err := targetConn.Write(data)
	if err != nil {
		return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
	}


@@ 402,22 431,18 @@ func (c *Conn) handleUDPRequest(
}

func (c *Conn) handleUDPResponse(
	targetConn net.PacketConn,
	clientConn net.PacketConn,
	targetAddr socksAddr,
	targetConn net.Conn,
	buf []byte,
	readTimeout time.Duration,
) error {
	// add a deadline for the read to avoid blocking forever
	_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
	n, addr, err := targetConn.ReadFrom(buf)
	n, err := targetConn.Read(buf)
	if err != nil {
		return fmt.Errorf("read from target: %w", err)
	}
	host, port, err := splitHostPort(addr.String())
	if err != nil {
		return fmt.Errorf("split host port: %w", err)
	}
	hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
	hdr := udpRequest{addr: targetAddr}
	pkt, err := hdr.marshal()
	if err != nil {
		return fmt.Errorf("marshal udp request: %w", err)