@@ 22,6 22,7 @@ import (
"log"
"net"
"strconv"
+ "tailscale.com/syncs"
"time"
"tailscale.com/types/logger"
@@ 81,6 82,12 @@ const (
addrTypeNotSupported replyCode = 8
)
+// UDP conn default buffer size and read timeout.
+const (
+ bufferSize = 8 * 1024
+ readTimeout = 5 * time.Second
+)
+
// Server is a SOCKS5 proxy server.
type Server struct {
// Logf optionally specifies the logger to use.
@@ 143,7 150,8 @@ type Conn struct {
clientConn net.Conn
request *request
- udpClientAddr net.Addr
+ udpClientAddr net.Addr
+ udpTargetConns syncs.Map[string, net.Conn]
}
// Run starts the new connection.
@@ 276,15 284,6 @@ func (c *Conn) handleUDP() error {
}
defer clientUDPConn.Close()
- serverUDPConn, err := net.ListenPacket("udp", "[::]:0")
- if err != nil {
- res := errorResponse(generalFailure)
- buf, _ := res.marshal()
- c.clientConn.Write(buf)
- return err
- }
- defer serverUDPConn.Close()
-
bindAddr, bindPort, err := splitHostPort(clientUDPConn.LocalAddr().String())
if err != nil {
return err
@@ 305,14 304,20 @@ func (c *Conn) handleUDP() error {
}
c.clientConn.Write(buf)
- return c.transferUDP(c.clientConn, clientUDPConn, serverUDPConn)
+ return c.transferUDP(c.clientConn, clientUDPConn)
}
-func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, targetConn net.PacketConn) error {
+func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn) error {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
- const bufferSize = 8 * 1024
- const readTimeout = 5 * time.Second
+
+ // close all target udp connections when the client connection is closed
+ defer func() {
+ c.udpTargetConns.Range(func(_ string, conn net.Conn) bool {
+ _ = conn.Close()
+ return true
+ })
+ }()
// client -> target
go func() {
@@ 323,7 328,7 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
case <-ctx.Done():
return
default:
- err := c.handleUDPRequest(clientConn, targetConn, buf, readTimeout)
+ err := c.handleUDPRequest(ctx, clientConn, buf)
if err != nil {
if isTimeout(err) {
continue
@@ 337,21 342,50 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
}
}()
+ // A UDP association terminates when the TCP connection that the UDP
+ // ASSOCIATE request arrived on terminates. RFC1928
+ _, err := io.Copy(io.Discard, associatedTCP)
+ if err != nil {
+ err = fmt.Errorf("udp associated tcp conn: %w", err)
+ }
+ return err
+}
+
+func (c *Conn) getOrDialTargetConn(
+ ctx context.Context,
+ clientConn net.PacketConn,
+ targetAddr string,
+) (net.Conn, error) {
+ host, port, err := splitHostPort(targetAddr)
+ if err != nil {
+ return nil, err
+ }
+
+ conn, loaded := c.udpTargetConns.Load(targetAddr)
+ if loaded {
+ return conn, nil
+ }
+ conn, err = c.srv.dial(ctx, "udp", targetAddr)
+ if err != nil {
+ return nil, err
+ }
+ c.udpTargetConns.Store(targetAddr, conn)
+
// target -> client
go func() {
- defer cancel()
buf := make([]byte, bufferSize)
+ addr := socksAddr{addrType: getAddrType(host), addr: host, port: port}
for {
select {
case <-ctx.Done():
return
default:
- err := c.handleUDPResponse(targetConn, clientConn, buf, readTimeout)
+ err := c.handleUDPResponse(clientConn, addr, conn, buf)
if err != nil {
if isTimeout(err) {
continue
}
- if errors.Is(err, net.ErrClosed) {
+ if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
return
}
c.logf("udp transfer: handle udp response fail: %v", err)
@@ 360,20 394,13 @@ func (c *Conn) transferUDP(associatedTCP net.Conn, clientConn net.PacketConn, ta
}
}()
- // A UDP association terminates when the TCP connection that the UDP
- // ASSOCIATE request arrived on terminates. RFC1928
- _, err := io.Copy(io.Discard, associatedTCP)
- if err != nil {
- err = fmt.Errorf("udp associated tcp conn: %w", err)
- }
- return err
+ return conn, nil
}
func (c *Conn) handleUDPRequest(
+ ctx context.Context,
clientConn net.PacketConn,
- targetConn net.PacketConn,
buf []byte,
- readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = clientConn.SetReadDeadline(time.Now().Add(readTimeout))
@@ 386,12 413,14 @@ func (c *Conn) handleUDPRequest(
if err != nil {
return fmt.Errorf("parse udp request: %w", err)
}
- targetAddr, err := net.ResolveUDPAddr("udp", req.addr.hostPort())
+
+ targetAddr := req.addr.hostPort()
+ targetConn, err := c.getOrDialTargetConn(ctx, clientConn, targetAddr)
if err != nil {
- c.logf("resolve target addr fail: %v", err)
+ return fmt.Errorf("dial target %s fail: %w", targetAddr, err)
}
- nn, err := targetConn.WriteTo(data, targetAddr)
+ nn, err := targetConn.Write(data)
if err != nil {
return fmt.Errorf("write to target %s fail: %w", targetAddr, err)
}
@@ 402,22 431,18 @@ func (c *Conn) handleUDPRequest(
}
func (c *Conn) handleUDPResponse(
- targetConn net.PacketConn,
clientConn net.PacketConn,
+ targetAddr socksAddr,
+ targetConn net.Conn,
buf []byte,
- readTimeout time.Duration,
) error {
// add a deadline for the read to avoid blocking forever
_ = targetConn.SetReadDeadline(time.Now().Add(readTimeout))
- n, addr, err := targetConn.ReadFrom(buf)
+ n, err := targetConn.Read(buf)
if err != nil {
return fmt.Errorf("read from target: %w", err)
}
- host, port, err := splitHostPort(addr.String())
- if err != nil {
- return fmt.Errorf("split host port: %w", err)
- }
- hdr := udpRequest{addr: socksAddr{addrType: getAddrType(host), addr: host, port: port}}
+ hdr := udpRequest{addr: targetAddr}
pkt, err := hdr.marshal()
if err != nil {
return fmt.Errorf("marshal udp request: %w", err)