~emersion/ident-proxy

26d93d5902d7b8ec2a6524157c39610bd074181a — Simon Ser 1 year, 8 months ago 720bedd master
Add support for the PROXY protocol
4 files changed, 79 insertions(+), 34 deletions(-)

M README.md
M go.mod
M main.go
M proxy.go
M README.md => README.md +3 -0
@@ 16,6 16,8 @@ Each backend is a URL in the form:
- `<host>[:port]` or `tcp://<host>[:port]`
- `unix://<path>`

Add `+proxy` to the URL scheme to enable the [PROXY protocol].

Example:

    ident-proxy unix:///run/identd 127.0.0.1:1113


@@ 25,3 27,4 @@ Example:
MIT

[ident]: https://tools.ietf.org/html/rfc1413
[PROXY protocol]: http://www.haproxy.org/download/2.3/doc/proxy-protocol.txt

M go.mod => go.mod +2 -0
@@ 1,3 1,5 @@
module git.sr.ht/~emersion/ident-proxy

go 1.15

require github.com/pires/go-proxyproto v0.1.3

M main.go => main.go +24 -18
@@ 17,37 17,43 @@ func main() {
	flag.StringVar(&listenAddr, "listen", ":113", "listening address")
	flag.Parse()

	var backends []dialFunc
	for _, backend := range flag.Args() {
		if !strings.Contains(backend, ":/") {
	var backends []*backend
	for _, backendURI := range flag.Args() {
		if !strings.Contains(backendURI, ":/") {
			// This is a raw domain name, make it an URL with an empty scheme
			backend = "//" + backend
			backendURI = "//" + backendURI
		}

		u, err := url.Parse(backend)
		u, err := url.Parse(backendURI)
		if err != nil {
			log.Fatalf("failed to parse backend URL %q: %v", backend, err)
			log.Fatalf("failed to parse backend URL %q: %v", backendURI, err)
		}

		var dial dialFunc
		proxyProto := strings.HasSuffix(u.Scheme, "+proxy")
		if proxyProto {
			u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
		}

		var network, addr string
		switch u.Scheme {
		case "ident", "tcp", "":
			host := u.Host
			if _, _, err := net.SplitHostPort(host); err != nil {
				host = host + ":113"
			}
			dial = func() (net.Conn, error) {
				return net.Dial("tcp", host)
			network = "tcp"
			addr = u.Host
			if _, _, err := net.SplitHostPort(addr); err != nil {
				addr = addr + ":113"
			}
		case "unix":
			dial = func() (net.Conn, error) {
				return net.Dial("unix", u.Host)
			}
			network = "unix"
			addr = u.Host
		default:
			log.Fatalf("failed to setup backend %q: unsupported scheme", backend)
			log.Fatalf("failed to setup backend %q: unsupported scheme", backendURI)
		}

		backends = append(backends, dial)
		backends = append(backends, &backend{
			Network: network,
			Address: addr,
			Proxy:   proxyProto,
		})
	}

	if len(backends) == 0 {

M proxy.go => proxy.go +50 -16
@@ 10,15 10,24 @@ import (
	"strconv"
	"strings"
	"time"

	"github.com/pires/go-proxyproto"
)

var identdTimeout = 10 * time.Second

type dialFunc func() (net.Conn, error)
type backend struct {
	Network, Address string
	Proxy            bool
}

func (b *backend) Dial() (net.Conn, error) {
	return net.Dial(b.Network, b.Address)
}

type proxy struct {
	Backends []dialFunc
	Debug bool
	Backends []*backend
	Debug    bool
}

func (p *proxy) Serve(ln net.Listener) error {


@@ 37,12 46,6 @@ func (p *proxy) handle(c net.Conn) {

	scanner := bufio.NewScanner(c)

	remoteHost, _, err := net.SplitHostPort(c.RemoteAddr().String())
	if err != nil {
		p.debug("failed to parse host/port from remote address %q: %v", c.RemoteAddr(), err)
		return
	}

	// We only read to read lines with two port numbers
	var buf [512]byte
	scanner.Buffer(buf[:], len(buf))


@@ 62,7 65,7 @@ func (p *proxy) handle(c net.Conn) {

		ctx := context.Background()
		ctx, cancel := context.WithTimeout(ctx, identdTimeout)
		system, ident, err := p.queryAll(ctx, localPort, remoteHost, remotePort)
		system, ident, err := p.queryAll(ctx, localPort, remotePort, c.LocalAddr(), c.RemoteAddr())
		cancel()
		if err != nil {
			p.debug("query %q failed: %v", l, err)


@@ 82,13 85,13 @@ type identResult struct {
	system, ident string
}

func (p *proxy) queryAll(ctx context.Context, localPort int, remoteHost string, remotePort int) (system, ident string, err error) {
func (p *proxy) queryAll(ctx context.Context, localPort, remotePort int, localAddr, remoteAddr net.Addr) (system, ident string, err error) {
	// TODO: propagate ctx to backend queries
	// TODO: use a connection pool somehow
	ch := make(chan *identResult, len(p.Backends))
	for _, dial := range p.Backends {
	for _, backend := range p.Backends {
		go func() {
			system, ident, err := query(dial, localPort, remoteHost, remotePort)
			system, ident, err := query(backend, localPort, remotePort, localAddr, remoteAddr)
			if err != nil {
				p.debug("backend returned error: %v", err)
				ch <- nil


@@ 137,13 140,19 @@ func parseIdentQuery(l string) (localPort, remotePort int, err error) {
	return localPort, remotePort, nil
}

func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (system, ident string, err error) {
	c, err := dial()
func query(backend *backend, localPort, remotePort int, localAddr, remoteAddr net.Addr) (system, ident string, err error) {
	c, err := backend.Dial()
	if err != nil {
		return "", "", err
	}
	defer c.Close()

	if backend.Proxy {
		if _, err := proxyHeader(localAddr, remoteAddr).WriteTo(c); err != nil {
			return "", "", err
		}
	}

	_, err = fmt.Fprintf(c, "%v, %v\r\n", localPort, remotePort)
	if err != nil {
		return "", "", err


@@ 154,7 163,7 @@ func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (sys
		if err := scanner.Err(); err != nil {
			return "", "", err
		} else {
			return "", "", io.EOF
			return "", "", io.ErrUnexpectedEOF
		}
	}
	l := scanner.Text()


@@ 178,3 187,28 @@ func query(dial dialFunc, localPort int, remoteHost string, remotePort int) (sys

	return system, ident, nil
}

func proxyHeader(localAddr, remoteAddr net.Addr) *proxyproto.Header {
	h := proxyproto.Header{
		Version: 1,
		Command: proxyproto.PROXY,
	}

	if localAddr.Network() == remoteAddr.Network() {
		switch localAddr := localAddr.(type) {
		case *net.TCPAddr:
			remoteAddr := remoteAddr.(*net.TCPAddr)
			if localIP4 := localAddr.IP.To4(); len(localIP4) == net.IPv4len {
				h.TransportProtocol = proxyproto.TCPv4
			} else if len(localAddr.IP) == net.IPv6len {
				h.TransportProtocol = proxyproto.TCPv6
			}
			h.SourceAddress = remoteAddr.IP
			h.DestinationAddress = localAddr.IP
			h.SourcePort = uint16(remoteAddr.Port)
			h.DestinationPort = uint16(localAddr.Port)
		}
	}

	return &h
}