@@ 22,7 22,6 @@ import (
"log"
"net"
"strconv"
- "tailscale.com/syncs"
"time"
"tailscale.com/types/logger"
@@ 151,7 150,7 @@ type Conn struct {
request *request
udpClientAddr net.Addr
- udpTargetConns syncs.Map[string, net.Conn]
+ udpTargetConns map[socksAddr]net.Conn
}
// Run starts the new connection.
@@ 311,17 310,18 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- // close all target udp connections when the client connection is closed
- defer func() {
- c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
- _ = conn.Close()
- return true
- })
- }()
-
// client -> target
go func() {
defer cancel()
+
+ c.udpTargetConns = make(map[socksAddr]net.Conn)
+ // close all target udp connections when the client connection is closed
+ defer func() {
+ for _, conn := range c.udpTargetConns {
+ _ = conn.Close()
+ }
+ }()
+
buf := make([]byte, bufferSize)
for {
select {
@@ 354,33 354,27 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) er
func (c *Conn) getOrDialTargetConn(
ctx context.Context,
clientConn net.PacketConn,
- targetAddr string,
+ targetAddr socksAddr,
) (net.Conn, error) {
- host, port, err := splitHostPort(targetAddr)
- if err != nil {
- return nil, err
- }
-
- conn, loaded := c.udpTargetConns.Load(targetAddr)
- if loaded {
+ conn, exist := c.udpTargetConns[targetAddr]
+ if exist {
return conn, nil
}
- conn, err = c.srv.dial(ctx, "udp", targetAddr)
+ conn, err := c.srv.dial(ctx, "udp", targetAddr.hostPort())
if err != nil {
return nil, err
}
- c.udpTargetConns.Store(targetAddr, conn)
+ c.udpTargetConns[targetAddr] = conn
// target -> client
go func() {
buf := make([]byte, bufferSize)
- addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
for {
select {
case <-ctx.Done():
return
default:
- err := c.handleUDPResponse(clientConn, addr, conn, buf)
+ err := c.handleUDPResponse(clientConn, targetAddr, conn, buf)
if err != nil {
if isTimeout(err) {
continue
@@ 414,18 408,17 @@ func (c *Conn) handleUDPRequest(
return fmt.Errorf("parse udp request: %w", err)
}
- targetAddr := req.addr.hostPort()
- targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
+ targetConn, err := c.getOrDialTargetConn(ctx, clientConn, req.addr)
if err != nil {
- return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
+ return fmt.Errorf("dial target %s fail: %w", req.addr, err)
}
nn, err := targetConn.Write(data)
if err != nil {
- return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
+ return fmt.Errorf("write to target %s fail: %w", req.addr, err)
}
if nn != len(data) {
- return fmt.Errorf("write to target %s fail: %w", targetAddr, io.ErrShortWrite)
+ return fmt.Errorf("write to target %s fail: %w", req.addr, io.ErrShortWrite)
}
return nil
}
@@ 652,10 645,15 @@ func (s socksAddr) marshal() ([]byte, error) {
pkt = binary.BigEndian.AppendUint16(pkt, s.port)
return pkt, nil
}
+
func (s socksAddr) hostPort() string {
return net.JoinHostPort(s.addr, strconv.Itoa(int(s.port)))
}
+func (s socksAddr) String() string {
+ return s.hostPort()
+}
+
// response contains the contents of
// a response packet sent from the proxy
// to the client.
@@ 169,12 169,25 @@ func TestReadPassword(t *testing.T) {
func TestUDP(t *testing.T) {
// backend UDP server which we'll use SOCKS5 to connect to
- listener, err := net.ListenPacket("udp", ":0")
- if err != nil {
- t.Fatal(err)
+ newUDPEchoServer := func() net.PacketConn {
+ listener, err := net.ListenPacket("udp", ":0")
+ if err != nil {
+ t.Fatal(err)
+ }
+ go udpEchoServer(listener)
+ return listener
}
- backendServerPort := listener.LocalAddr().(*net.UDPAddr).Port
- go udpEchoServer(listener)
+
+ const echoServerNumber = 3
+ echoServerListener := make([]net.PacketConn, echoServerNumber)
+ for i := 0; i < echoServerNumber; i++ {
+ echoServerListener[i] = newUDPEchoServer()
+ }
+ defer func() {
+ for i := 0; i < echoServerNumber; i++ {
+ _ = echoServerListener[i].Close()
+ }
+ }()
// SOCKS5 server
socks5, err := net.Listen("tcp", ":0")
@@ 184,84 197,93 @@ func TestUDP(t *testing.T) {
socks5Port := socks5.Addr().(*net.TCPAddr).Port
go socks5Server(socks5)
- // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
- conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
- if err != nil {
- t.Fatal(err)
- }
- _, err = conn.Write([]byte{0x05, 0x01, 0x00}) // client hello with no auth
- if err != nil {
- t.Fatal(err)
- }
- buf := make([]byte, 1024)
- n, err := conn.Read(buf) // server hello
- if err != nil {
- t.Fatal(err)
- }
- if n != 2 || buf[0] != 0x05 || buf[1] != 0x00 {
- t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
- }
+ // make a socks5 udpAssociate conn
+ newUdpAssociateConn := func() (socks5Conn net.Conn, socks5UDPAddr socksAddr) {
+ // net/proxy don't support UDP, so we need to manually send the SOCKS5 UDP request
+ conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", socks5Port))
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = conn.Write([]byte{socks5Version, 0x01, noAuthRequired}) // client hello with no auth
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, 1024)
+ n, err := conn.Read(buf) // server hello
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n != 2 || buf[0] != socks5Version || buf[1] != noAuthRequired {
+ t.Fatalf("got: %q want: 0x05 0x00", buf[:n])
+ }
- targetAddr := socksAddr{
- addrType: domainName,
- addr: "localhost",
- port: uint16(backendServerPort),
- }
- targetAddrPkt, err := targetAddr.marshal()
- if err != nil {
- t.Fatal(err)
- }
- _, err = conn.Write(append([]byte{0x05, 0x03, 0x00}, targetAddrPkt...)) // client reqeust
- if err != nil {
- t.Fatal(err)
- }
+ targetAddr := socksAddr{addrType: ipv4, addr: "0.0.0.0", port: 0}
+ targetAddrPkt, err := targetAddr.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, err = conn.Write(append([]byte{socks5Version, byte(udpAssociate), 0x00}, targetAddrPkt...)) // client reqeust
+ if err != nil {
+ t.Fatal(err)
+ }
- n, err = conn.Read(buf) // server response
- if err != nil {
- t.Fatal(err)
- }
- if n < 3 || !bytes.Equal(buf[:3], []byte{0x05, 0x00, 0x00}) {
- t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+ n, err = conn.Read(buf) // server response
+ if err != nil {
+ t.Fatal(err)
+ }
+ if n < 3 || !bytes.Equal(buf[:3], []byte{socks5Version, 0x00, 0x00}) {
+ t.Fatalf("got: %q want: 0x05 0x00 0x00", buf[:n])
+ }
+ udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ return conn, udpProxySocksAddr
}
- udpProxySocksAddr, err := parseSocksAddr(bytes.NewReader(buf[3:n]))
- if err != nil {
- t.Fatal(err)
+
+ conn, udpProxySocksAddr := newUdpAssociateConn()
+ defer conn.Close()
+
+ sendUDPAndWaitResponse := func(socks5UDPConn net.Conn, addr socksAddr, body []byte) (responseBody []byte) {
+ udpPayload, err := (&udpRequest{addr: addr}).marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ udpPayload = append(udpPayload, body...)
+ _, err = socks5UDPConn.Write(udpPayload)
+ if err != nil {
+ t.Fatal(err)
+ }
+ buf := make([]byte, 1024)
+ n, err := socks5UDPConn.Read(buf)
+ if err != nil {
+ t.Fatal(err)
+ }
+ _, responseBody, err = parseUDPRequest(buf[:n])
+ if err != nil {
+ t.Fatal(err)
+ }
+ return responseBody
}
udpProxyAddr, err := net.ResolveUDPAddr("udp", udpProxySocksAddr.hostPort())
if err != nil {
t.Fatal(err)
}
- udpConn, err := net.DialUDP("udp", nil, udpProxyAddr)
- if err != nil {
- t.Fatal(err)
- }
- udpPayload, err := (&udpRequest{addr: targetAddr}).marshal()
- if err != nil {
- t.Fatal(err)
- }
- udpPayload = append(udpPayload, []byte("Test")...)
- _, err = udpConn.Write(udpPayload) // send udp package
- if err != nil {
- t.Fatal(err)
- }
- n, _, err = udpConn.ReadFrom(buf)
- if err != nil {
- t.Fatal(err)
- }
- _, responseBody, err := parseUDPRequest(buf[:n]) // read udp response
- if err != nil {
- t.Fatal(err)
- }
- if string(responseBody) != "Test" {
- t.Fatalf("got: %q want: Test", responseBody)
- }
- err = udpConn.Close()
+ socks5UDPConn, err := net.DialUDP("udp", nil, udpProxyAddr)
if err != nil {
t.Fatal(err)
}
- err = conn.Close()
- if err != nil {
- t.Fatal(err)
+ defer socks5UDPConn.Close()
+
+ for i := 0; i < echoServerNumber; i++ {
+ port := echoServerListener[i].LocalAddr().(*net.UDPAddr).Port
+ addr := socksAddr{addrType: ipv4, addr: "127.0.0.1", port: uint16(port)}
+ requestBody := []byte(fmt.Sprintf("Test %d", i))
+ responseBody := sendUDPAndWaitResponse(socks5UDPConn, addr, requestBody)
+ if !bytes.Equal(requestBody, responseBody) {
+ t.Fatalf("got: %q want: %q", responseBody, requestBody)
+ }
}
}