~egtann/srp

f26e48286997965d579ee5a9c8b5dd5b9b3e5b87 — Evan Tann 3 years ago 44f8e0f
fix https issues
2 files changed, 33 insertions(+), 12 deletions(-)

M cmd/srp/main.go
M proxy.go
M cmd/srp/main.go => cmd/srp/main.go +22 -8
@@ 67,27 67,41 @@ func main() {
	if len(*sslURL) > 0 {
		hosts := append(reg.Hosts(), selfURL.Host)
		m := &autocert.Manager{
			Cache:      autocert.DirCache("certs"),
			Prompt:     autocert.AcceptTOS,
			HostPolicy: autocert.HostWhitelist(hosts...),
		}
		srv.TLSConfig = &tls.Config{GetCertificate: m.GetCertificate}
		getCert := func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
			log.Printf("get cert for %s\n", hello.ServerName)
			cert, err := m.GetCertificate(hello)
			if err != nil {
				log.Println("failed to get cert:", err)
			}
			return cert, err
		}
		srv.TLSConfig = &tls.Config{GetCertificate: getCert}
		go func() {
			err = http.ListenAndServe(":http", m.HTTPHandler(nil))
			if err != nil {
				log.Fatal(errors.Wrap(err, "autocert"))
			}
		}()
		srv.Addr = ":https"
		port = "443"
		srv.Addr = ":https"
		go func() {
			if err = srv.ListenAndServeTLS("", ""); err != nil {
				log.Fatal(err)
			}
		}()
	} else {
		srv.Addr = ":" + port
		go func() {
			if err = srv.ListenAndServe(); err != nil {
				log.Fatal(err)
			}
		}()
	}
	go func() {
		log.Println("listening on", port)
		if err = srv.ListenAndServe(); err != nil {
			log.Fatal(err)
		}
	}()
	log.Println("listening on", port)
	if err = proxy.CheckHealth(); err != nil {
		log.Println("check health", err)
	}

M proxy.go => proxy.go +11 -4
@@ 208,7 208,12 @@ func newTransport(reg Registry) http.RoundTripper {
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		Dial: func(network, addr string) (net.Conn, error) {
			host, ok := reg[addr]
			// Trim training ":80"
			if len(addr) <= 3 {
				return nil, fmt.Errorf("invalid address %q", addr)
			}
			addrShort := addr[:len(addr)-3]
			host, ok := reg[addrShort]
			if !ok {
				return nil, fmt.Errorf("no host for %s", addr)
			}


@@ 218,16 223,18 @@ func newTransport(reg Registry) http.RoundTripper {
			}
			randInt := rand.Int()
			endpoint := endpoints[randInt%len(endpoints)]
			conn, err := net.Dial(network, endpoint)
			conn, err := net.Dial(network, endpoint+":80")
			if len(endpoints) < 2 || err == nil {
				return conn, err
			}
			// Retry on other endpoints if there are multiple
			conn, err = net.Dial(network, endpoints[(randInt+1)%len(endpoints)])
			endpoint = endpoints[(randInt+1)%len(endpoints)]
			conn, err = net.Dial(network, endpoint+":80")
			if len(endpoints) < 3 || err == nil {
				return conn, err
			}
			return net.Dial(network, endpoints[(randInt+2)%len(endpoints)])
			endpoint = endpoints[(randInt+2)%len(endpoints)]
			return net.Dial(network, endpoint+":80")
		},
		MaxIdleConns:          100,
		IdleConnTimeout:       30 * time.Second,