~emersion/tlstunnel

2fdea9d4ed2d7e96ec435b559d9c6920023a3b3c — Simon Ser 8 days ago ec2a768
Move back directive processing to tlstunnel package
3 files changed, 121 insertions(+), 106 deletions(-)

M cmd/tlstunnel/main.go
A directives.go
M server.go
M cmd/tlstunnel/main.go => cmd/tlstunnel/main.go +2 -106
@@ 2,11 2,7 @@ package main

import (
	"flag"
	"fmt"
	"log"
	"net"
	"net/url"
	"strings"

	"git.sr.ht/~emersion/tlstunnel"
)


@@ 24,19 20,8 @@ func main() {

	srv := tlstunnel.NewServer()

	for _, d := range cfg.Children {
		var err error
		switch d.Name {
		case "frontend":
			err = parseFrontend(srv, d)
		case "tls":
			err = parseTLS(srv, d)
		default:
			log.Fatalf("unknown %q directive", d.Name)
		}
		if err != nil {
			log.Fatalf("directive %q: %v", d.Name, err)
		}
	if err := srv.Load(cfg); err != nil {
		log.Fatal(err)
	}

	if err := srv.Start(); err != nil {


@@ 45,92 30,3 @@ func main() {

	select {}
}

func parseFrontend(srv *tlstunnel.Server, d *tlstunnel.Directive) error {
	frontend := &tlstunnel.Frontend{Server: srv}
	srv.Frontends = append(srv.Frontends, frontend)

	// TODO: support multiple backends
	backendDirective := d.ChildByName("backend")
	if backendDirective == nil {
		return fmt.Errorf("missing backend directive in frontend block")
	}
	if err := parseBackend(&frontend.Backend, backendDirective); err != nil {
		return err
	}

	for _, listenAddr := range d.Params {
		host, port, err := net.SplitHostPort(listenAddr)
		if err != nil {
			return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err)
		}

		// TODO: come up with something more robust
		var name string
		if host != "" && host != "localhost" && net.ParseIP(host) == nil {
			name = host
			host = ""

			srv.ManagedNames = append(srv.ManagedNames, name)
		}

		addr := net.JoinHostPort(host, port)

		ln := srv.RegisterListener(addr)
		if err := ln.RegisterFrontend(name, frontend); err != nil {
			return err
		}
	}

	return nil
}

func parseBackend(backend *tlstunnel.Backend, d *tlstunnel.Directive) error {
	var backendURI string
	if err := d.ParseParams(&backendURI); err != nil {
		return err
	}
	if !strings.Contains(backendURI, ":/") {
		// This is a raw domain name, make it an URL with an empty scheme
		backendURI = "//" + backendURI
	}

	u, err := url.Parse(backendURI)
	if err != nil {
		return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err)
	}

	if strings.HasSuffix(u.Scheme, "+proxy") {
		u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
		backend.Proxy = true
	}

	switch u.Scheme {
	case "", "tcp":
		backend.Network = "tcp"
		backend.Address = u.Host
	case "unix":
		backend.Network = "unix"
		backend.Address = u.Host
	default:
		return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI)
	}

	return nil
}

func parseTLS(srv *tlstunnel.Server, d *tlstunnel.Directive) error {
	for _, child := range d.Children {
		switch child.Name {
		case "acme_ca":
			var caURL string
			if err := child.ParseParams(&caURL); err != nil {
				return err
			}
			srv.ACMEManager.CA = caURL
		default:
			return fmt.Errorf("unknown %q directive", child.Name)
		}
	}
	return nil
}

A directives.go => directives.go +115 -0
@@ 0,0 1,115 @@
package tlstunnel

import (
	"fmt"
	"net"
	"net/url"
	"strings"
)

func parseConfig(srv *Server, cfg *Directive) error {
	for _, d := range cfg.Children {
		var err error
		switch d.Name {
		case "frontend":
			err = parseFrontend(srv, d)
		case "tls":
			err = parseTLS(srv, d)
		default:
			return fmt.Errorf("unknown %q directive", d.Name)
		}
		if err != nil {
			return fmt.Errorf("directive %q: %v", d.Name, err)
		}
	}
	return nil
}

func parseFrontend(srv *Server, d *Directive) error {
	frontend := &Frontend{Server: srv}
	srv.Frontends = append(srv.Frontends, frontend)

	// TODO: support multiple backends
	backendDirective := d.ChildByName("backend")
	if backendDirective == nil {
		return fmt.Errorf("missing backend directive in frontend block")
	}
	if err := parseBackend(&frontend.Backend, backendDirective); err != nil {
		return err
	}

	for _, listenAddr := range d.Params {
		host, port, err := net.SplitHostPort(listenAddr)
		if err != nil {
			return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err)
		}

		// TODO: come up with something more robust
		var name string
		if host != "" && host != "localhost" && net.ParseIP(host) == nil {
			name = host
			host = ""

			srv.ManagedNames = append(srv.ManagedNames, name)
		}

		addr := net.JoinHostPort(host, port)

		ln := srv.RegisterListener(addr)
		if err := ln.RegisterFrontend(name, frontend); err != nil {
			return err
		}
	}

	return nil
}

func parseBackend(backend *Backend, d *Directive) error {
	var backendURI string
	if err := d.ParseParams(&backendURI); err != nil {
		return err
	}
	if !strings.Contains(backendURI, ":/") {
		// This is a raw domain name, make it an URL with an empty scheme
		backendURI = "//" + backendURI
	}

	u, err := url.Parse(backendURI)
	if err != nil {
		return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err)
	}

	if strings.HasSuffix(u.Scheme, "+proxy") {
		u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
		backend.Proxy = true
	}

	switch u.Scheme {
	case "", "tcp":
		backend.Network = "tcp"
		backend.Address = u.Host
	case "unix":
		backend.Network = "unix"
		backend.Address = u.Host
	default:
		return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI)
	}

	return nil
}

func parseTLS(srv *Server, d *Directive) error {
	for _, child := range d.Children {
		switch child.Name {
		case "acme_ca":
			var caURL string
			if err := child.ParseParams(&caURL); err != nil {
				return err
			}
			srv.ACMEManager.CA = caURL
		default:
			return fmt.Errorf("unknown %q directive", child.Name)
		}
	}
	return nil
}

M server.go => server.go +4 -0
@@ 38,6 38,10 @@ func NewServer() *Server {
	}
}

func (srv *Server) Load(cfg *Directive) error {
	return parseConfig(srv, cfg)
}

func (srv *Server) RegisterListener(addr string) *Listener {
	// TODO: normalize addr with net.LookupPort
	ln, ok := srv.Listeners[addr]