@@ 17,37 17,43 @@ func main() {
flag.StringVar(&listenAddr, "listen", ":113", "listening address")
flag.Parse()
- var backends []dialFunc
- for _, backend := range flag.Args() {
- if !strings.Contains(backend, ":/") {
+ var backends []*backend
+ for _, backendURI := range flag.Args() {
+ if !strings.Contains(backendURI, ":/") {
// This is a raw domain name, make it an URL with an empty scheme
- backend = "//" + backend
+ backendURI = "//" + backendURI
}
- u, err := url.Parse(backend)
+ u, err := url.Parse(backendURI)
if err != nil {
- log.Fatalf("failed to parse backend URL %q: %v", backend, err)
+ log.Fatalf("failed to parse backend URL %q: %v", backendURI, err)
}
- var dial dialFunc
+ proxyProto := strings.HasSuffix(u.Scheme, "+proxy")
+ if proxyProto {
+ u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
+ }
+
+ var network, addr string
switch u.Scheme {
case "ident", "tcp", "":
- host := u.Host
- if _, _, err := net.SplitHostPort(host); err != nil {
- host = host + ":113"
- }
- dial = func() (net.Conn, error) {
- return net.Dial("tcp", host)
+ network = "tcp"
+ addr = u.Host
+ if _, _, err := net.SplitHostPort(addr); err != nil {
+ addr = addr + ":113"
}
case "unix":
- dial = func() (net.Conn, error) {
- return net.Dial("unix", u.Host)
- }
+ network = "unix"
+ addr = u.Host
default:
- log.Fatalf("failed to setup backend %q: unsupported scheme", backend)
+ log.Fatalf("failed to setup backend %q: unsupported scheme", backendURI)
}
- backends = append(backends, dial)
+ backends = append(backends, &backend{
+ Network: network,
+ Address: addr,
+ Proxy: proxyProto,
+ })
}
if len(backends) == 0 {
@@ 10,15 10,24 @@ import (
"strconv"
"strings"
"time"
+
+ "github.com/pires/go-proxyproto"
)
var identdTimeout = 10 * time.Second
-type dialFunc func() (net.Conn, error)
+type backend struct {
+ Network, Address string
+ Proxy bool
+}
+
+func (b *backend) Dial() (net.Conn, error) {
+ return net.Dial(b.Network, b.Address)
+}
type proxy struct {
- Backends []dialFunc
- Debug bool
+ Backends []*backend
+ Debug bool
}
func (p *proxy) Serve(ln net.Listener) error {
@@ 37,12 46,6 @@ func (p *proxy) handle(c net.Conn) {
scanner := bufio.NewScanner(c)
- remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String())
- if err != nil {
- p.debug("failed to parse host/port from remote address %q: %v", c.RemoteAddr(), err)
- return
- }
-
// We only read to read lines with two port numbers
var buf [512]byte
scanner.Buffer(buf[:], len(buf))
@@ 62,7 65,7 @@ func (p *proxy) handle(c net.Conn) {
ctx := context.Background()
ctx, cancel := context.WithTimeout(ctx, identdTimeout)
- system, ident, err := p.queryAll(ctx, localPort, remoteHost, remotePort)
+ system, ident, err := p.queryAll(ctx, localPort, remotePort, c.LocalAddr(), c.RemoteAddr())
cancel()
if err != nil {
p.debug("query %q failed: %v", l, err)
@@ 82,13 85,13 @@ type identResult struct {
system, ident string
}
-func (p *proxy) queryAll(ctx context.Context, localPort int, remoteHost string, remotePort int) (system, ident string, err error) {
+func (p *proxy) queryAll(ctx context.Context, localPort, remotePort int, localAddr, remoteAddr net.Addr) (system, ident string, err error) {
// TODO: propagate ctx to backend queries
// TODO: use a connection pool somehow
ch := make(chan *identResult, len(p.Backends))
- for _, dial := range p.Backends {
+ for _, backend := range p.Backends {
go func() {
- system, ident, err := query(dial, localPort, remoteHost, remotePort)
+ system, ident, err := query(backend, localPort, remotePort, localAddr, remoteAddr)
if err != nil {
p.debug("backend returned error: %v", err)
ch <- nil
@@ 137,13 140,19 @@ func parseIdentQuery(l string) (localPort, remotePort int, err error) {
return localPort, remotePort, nil
}
-func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (system, ident string, err error) {
- c, err := dial()
+func query(backend *backend, localPort, remotePort int, localAddr, remoteAddr net.Addr) (system, ident string, err error) {
+ c, err := backend.Dial()
if err != nil {
return "", "", err
}
defer c.Close()
+ if backend.Proxy {
+ if _, err := proxyHeader(localAddr, remoteAddr).WriteTo(c); err != nil {
+ return "", "", err
+ }
+ }
+
_, err = fmt.Fprintf(c, "%v, %v\r\n", localPort, remotePort)
if err != nil {
return "", "", err
@@ 154,7 163,7 @@ func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (sys
if err := scanner.Err(); err != nil {
return "", "", err
} else {
- return "", "", io.EOF
+ return "", "", io.ErrUnexpectedEOF
}
}
l := scanner.Text()
@@ 178,3 187,28 @@ func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (sys
return system, ident, nil
}
+
+func proxyHeader(localAddr, remoteAddr net.Addr) *proxyproto.Header {
+ h := proxyproto.Header{
+ Version: 1,
+ Command: proxyproto.PROXY,
+ }
+
+ if localAddr.Network() == remoteAddr.Network() {
+ switch localAddr := localAddr.(type) {
+ case *net.TCPAddr:
+ remoteAddr := remoteAddr.(*net.TCPAddr)
+ if localIP4 := localAddr.IP.To4(); len(localIP4) == net.IPv4len {
+ h.TransportProtocol = proxyproto.TCPv4
+ } else if len(localAddr.IP) == net.IPv6len {
+ h.TransportProtocol = proxyproto.TCPv6
+ }
+ h.SourceAddress = remoteAddr.IP
+ h.DestinationAddress = localAddr.IP
+ h.SourcePort = uint16(remoteAddr.Port)
+ h.DestinationPort = uint16(localAddr.Port)
+ }
+ }
+
+ return &h
+}