A api_test.go => api_test.go +29 -0
@@ 0,0 1,29 @@
+package srp
+
+import (
+ "net"
+ "testing"
+)
+
+func TestMaskIP(t *testing.T) {
+ ip := "10.128.0.0/24"
+ maskedIP, mask, err := maskIP(ip)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ip = "10.128.0.10"
+ if maskedIP != net.ParseIP(ip).Mask(mask).String() {
+ t.Fatal("expected match")
+ }
+
+ ip = "1.1.1.1"
+ if maskedIP == net.ParseIP(ip).Mask(mask).String() {
+ t.Fatal("expected no match")
+ }
+
+ ip = ""
+ if maskedIP == net.ParseIP(ip).Mask(mask).String() {
+ t.Fatal("expected no match")
+ }
+}
M cmd/srp/main.go => cmd/srp/main.go +1 -1
@@ 62,7 62,7 @@ func main() {
srv := &http.Server{
// TODO(egtann) wrap proxy to allow API requests over the
// whitelisted subnet
- Handler: proxy,
+ Handler: proxy.Handler(),
ReadTimeout: timeout,
WriteTimeout: timeout,
MaxHeaderBytes: 1 << 20,
M proxy.go => proxy.go +0 -18
@@ 105,29 105,11 @@ func NewProxy(log Logger, reg *Registry) *ReverseProxy {
}
}
-// ServeHTTP implements the http.RoundTripper interface.
func (r *ReverseProxy) ServeHTTP(w http.ResponseWriter, req *http.Request) {
r.mu.RLock()
defer r.mu.RUnlock()
- // Only allow GET and HEAD requests to the API
- switch req.Method {
- case "GET", "HEAD":
- if strings.TrimPrefix(req.URL.Path, "/") == "services" {
- reg := cloneRegistryNoLock(r.reg)
- for _, srv := range reg.Services {
- // Only show live backends
- srv.Backends = srv.liveBackends
- }
- err := json.NewEncoder(w).Encode(reg.Services)
- if err != nil {
- r.log.Printf("failed to encode registry: %s", err)
- }
- return
- }
- }
r.rp.ServeHTTP(w, req)
- return
}
func newRegistry(r io.Reader) (*Registry, error) {