~egtann/srp

ref: 88118f3ca0b8930e61ddc56e7b9d20260772ef31 srp/proxy.go -rw-r--r-- 4.2 KiB
88118f3c — Evan Tann initial commit 3 years ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
package srp

import (
	"encoding/json"
	"fmt"
	"io/ioutil"
	"log"
	"math/rand"
	"net"
	"net/http"
	"net/http/httputil"
	"sync"
	"time"

	"github.com/pkg/errors"
)

// ReverseProxy is an http.Handler that forwards requests through an
// httputil.ReverseProxy whose transport picks a backend from a Registry.
//
// mu guards reg and rp (notably rp.Transport); CheckHealth replaces both
// while holding the write lock.
type ReverseProxy struct {
	mu  sync.RWMutex
	reg Registry
	rp  httputil.ReverseProxy

	log Logger
}

// Registry maps each frontend address to the backends that serve it. It is
// decoded from a JSON configuration file by NewRegistry; unexported fields
// are never set from JSON.
//
// NOTE(review): newTransport's Dial looks entries up by the address it is
// asked to dial, which net/http presumably passes as "host:port" — confirm
// that configuration keys include the port.
type Registry map[string]*struct {
	// HealthPath is the URL path CheckHealth requests on each backend. When
	// empty, this frontend's backends are not health-checked.
	HealthPath   string
	// Backends are the addresses dialed for this frontend. They are passed
	// straight to net.Dial, so presumably "host:port".
	Backends     []string
	// liveBackends is set by CheckHealth to the backends that answered
	// 200 OK on the most recent check.
	liveBackends []string
}

// Logger is the logging dependency held by ReverseProxy. Any type with a
// Printf method, such as *log.Logger, satisfies it.
type Logger interface {
	Printf(format string, args ...interface{})
}

// NewProxy returns a ReverseProxy that routes through reg and keeps log for
// reporting. Each outgoing request is rewritten to plain HTTP with the
// original Host header as its URL host; the transport built from reg then
// decides which backend actually gets dialed.
func NewProxy(log Logger, reg Registry) *ReverseProxy {
	proxy := &ReverseProxy{log: log, reg: reg}
	proxy.rp = httputil.ReverseProxy{
		Director: func(req *http.Request) {
			req.URL.Scheme = "http"
			req.URL.Host = req.Host
		},
		Transport: newTransport(reg),
	}
	return proxy
}

// ServeHTTP implements the http.Handler interface by forwarding req through
// the current proxy configuration.
//
// The read lock is held only long enough to copy that configuration, not
// for the whole request. Holding it for the full request would make
// CheckHealth's write lock wait for every in-flight request, and because a
// pending writer blocks new readers on a sync.RWMutex, every new request
// would stall behind the slowest one (e.g. a long poll or a stream).
func (r *ReverseProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	r.mu.RLock()
	rp := r.rp
	r.mu.RUnlock()
	rp.ServeHTTP(w, req)
}

// NewRegistry reads and decodes the JSON configuration file at filename.
// It reports an error if the file cannot be read or decoded, or if any
// frontend host has no backends (including a frontend whose entry is null).
func NewRegistry(filename string) (Registry, error) {
	byt, err := ioutil.ReadFile(filename)
	if err != nil {
		return nil, errors.Wrapf(err, "read file %q", filename)
	}
	reg := Registry{}
	if err = json.Unmarshal(byt, &reg); err != nil {
		return nil, errors.Wrap(err, "unmarshal config")
	}
	for host, frontend := range reg {
		// A JSON null entry decodes to a nil pointer; report it like an
		// entry without backends rather than dereferencing it.
		if frontend == nil || len(frontend.Backends) == 0 {
			return nil, fmt.Errorf("missing backends for %q", host)
		}
	}
	return reg, nil
}

// Hosts returns every frontend host in the registry in map-iteration
// (unspecified) order. The returned slice is never nil, even when the
// registry is empty.
func (r Registry) Hosts() []string {
	hosts := make([]string, 0, len(r))
	for host := range r {
		hosts = append(hosts, host)
	}
	return hosts
}

// CheckHealth of backend servers in the registry concurrently, up to 10 at a
// time. If an unexpected error is returned during any of the checks,
// CheckHealth immediately exits, reporting that error and leaving the current
// registry and transport untouched.
//
// Checks run against a private deep copy of the registry. Only after every
// frontend has been checked is that copy swapped in, together with a new
// transport built from it.
func (r *ReverseProxy) CheckHealth(client *http.Client) error {
	// Copy each entry, not just the map: the current registry's structs may
	// still be read by other goroutines, so they must never be mutated here.
	r.mu.RLock()
	regClone := make(Registry, len(r.reg))
	for k, v := range r.reg {
		entry := *v
		regClone[k] = &entry
	}
	r.mu.RUnlock()

	semaphore := make(chan int, 10)
	for _, frontend := range regClone {
		if len(frontend.HealthPath) == 0 {
			continue
		}
		// Room for one result per backend: no ping goroutine can block
		// forever on a send, even when we return early on an error.
		ipchan := make(chan string, len(frontend.Backends))
		errchan := make(chan error, len(frontend.Backends))
		for _, ip := range frontend.Backends {
			target := "http://" + ip + frontend.HealthPath
			semaphore <- 1 // acquire a slot; ping is responsible for freeing it
			go ping(client, ip, target, semaphore, ipchan, errchan)
		}
		liveBackends := []string{}
		for range frontend.Backends {
			select {
			case ip := <-ipchan:
				if ip != "" {
					liveBackends = append(liveBackends, ip)
				}
			case err := <-errchan:
				return errors.Wrap(err, "err on channel")
			}
		}
		frontend.liveBackends = liveBackends
	}

	// Update the registry
	r.mu.Lock()
	defer r.mu.Unlock()
	r.reg = regClone
	r.rp.Transport = newTransport(regClone)
	return nil
}

// newTransport returns an http.Transport whose Dial routes each connection
// to a backend registered in reg under the dialed address.
//
// Candidate backends are captured when the transport is built:
//   - the backends that passed the last health check (liveBackends), if any;
//   - otherwise all configured Backends.
//
// A dial starts at a random candidate and, on failure, tries up to two more
// distinct candidates in order. Dialing an address with no registered
// backends returns an error.
func newTransport(reg Registry) http.RoundTripper {
	// Snapshot the candidate lists so Dial never reads the registry's
	// structs, which CheckHealth may update after this transport is built.
	candidates := make(map[string][]string, len(reg))
	for host, frontend := range reg {
		if frontend == nil {
			continue
		}
		endpoints := frontend.liveBackends
		if len(endpoints) == 0 {
			// No health data yet, health checks disabled, or every backend
			// failed its check: fall back to the full configured list.
			endpoints = frontend.Backends
		}
		if len(endpoints) > 0 {
			candidates[host] = endpoints
		}
	}
	return &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		Dial: func(network, addr string) (net.Conn, error) {
			endpoints := candidates[addr]
			if len(endpoints) == 0 {
				return nil, fmt.Errorf("no backends for %q", addr)
			}
			attempts := len(endpoints)
			if attempts > 3 {
				attempts = 3
			}
			// Intn keeps start below len(endpoints), so start+i cannot
			// overflow into a negative index the way rand.Int()+i could.
			start := rand.Intn(len(endpoints))
			var lastErr error
			for i := 0; i < attempts; i++ {
				conn, err := net.Dial(network, endpoints[(start+i)%len(endpoints)])
				if err == nil {
					return conn, nil
				}
				lastErr = err
			}
			return nil, lastErr
		},
		MaxIdleConns:          100,
		IdleConnTimeout:       30 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ResponseHeaderTimeout: 10 * time.Second,
	}
}

// ping health-checks a single backend and reports the outcome on exactly one
// channel:
//   - ipchan receives ip when target answered 200 OK;
//   - ipchan receives "" when the request failed or returned another status
//     (both are logged);
//   - errchan receives an error for unexpected failures (building the
//     request, closing the response body).
//
// The caller acquires a slot by sending on semaphore before starting ping.
// ping frees that slot by receiving from semaphore, and does so before
// reporting. That way a caller that only reads results after launching every
// check can always acquire the next slot.
func ping(
	client *http.Client,
	ip, target string,
	semaphore chan int,
	ipchan chan string,
	errchan chan error,
) {
	live, err := probeBackend(client, ip, target)
	<-semaphore

	switch {
	case err != nil:
		errchan <- err
	case live:
		ipchan <- ip
	default:
		ipchan <- ""
	}
}

// probeBackend sends a GET to target and reports whether it answered 200 OK.
// Connection failures and non-200 statuses are logged and reported as not
// live. Only unexpected errors are returned.
func probeBackend(client *http.Client, ip, target string) (bool, error) {
	req, err := http.NewRequest("GET", target, nil)
	if err != nil {
		return false, errors.Wrap(err, "new request")
	}
	resp, err := client.Do(req)
	if err != nil {
		log.Printf("%s: failed connection: %s", ip, err)
		return false, nil
	}
	if err = resp.Body.Close(); err != nil {
		return false, errors.Wrap(err, "close resp body")
	}
	if resp.StatusCode != http.StatusOK {
		log.Printf("%s: expected status code 200, got %d",
			ip, resp.StatusCode)
		return false, nil
	}
	return true, nil
}