937534285786521c2dce57849569a59c1108e569 — Martin Angers 7 months ago 9fdaae1 + 275531d
Merge branch 'master' of git.sr.ht:~mna/kick
M TODO.md => TODO.md +1 -0
@@ 13,3 13,4 @@
 11. Structured logging
 12. Access path / router variables from handler
 13. Disable http/2 option (maybe to support websocket)
+14. Health checks, including DBs

M builder/builder.go => builder/builder.go +6 -2
@@ 156,7 156,11 @@ func Handler(s *kick.Server) (http.Handler, error) {
 	// routes, respectively).
 	configureGetAndOptions(router, s.Routes)
 
-	h := combineMiddleware(router, mw)
+	routerHandler := http.Handler(router)
+	if s.Root != nil && s.Root.RouterMiddleware != nil {
+		routerHandler = s.Root.RouterMiddleware(router)
+	}
+	h := combineMiddleware(routerHandler, mw)
 	return h, nil
 }
 


@@ 172,7 176,7 @@ func rootMiddleware(s *kick.Server, router *httprouter.Router) ([]middleware, er
 	var mw []middleware
 
 	if fn := s.Root.LoggingFunc; fn != nil {
-		mw = append(mw, logging(fn))
+		mw = append(mw, logging(s.Root.RequestIDHeader, fn))
 	}
 	if fn := s.Root.PanicRecoveryFunc; fn != nil {
 		mw = append(mw, panicRecovery(fn))

M builder/builder_test.go => builder/builder_test.go +213 -8
@@ 10,6 10,7 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
+	"net/http/cookiejar"
 	"net/http/httptest"
 	"os"
 	"strconv"


@@ 18,6 19,8 @@ import (
 	"time"
 
 	"git.sr.ht/~mna/kick"
+	"github.com/gorilla/csrf"
+	"github.com/gorilla/securecookie"
 	"github.com/sirupsen/logrus"
 	"github.com/unrolled/secure"
 )


@@ 293,8 296,6 @@ func TestHandler_Valid(t *testing.T) {
   "value": "hello"
 }`
 
-	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))
-
 	noslashCORS := &kick.CORSConfig{
 		AllowedMethods: []string{"GET", "PUT"},
 		AllowedOrigins: []string{"http://notexample.com", "notexample.com"},


@@ 691,13 692,217 @@ func TestHandler_Valid(t *testing.T) {
 				}
 			}
 
-			// TODO: HEAD body should be verified by an actual client call
-			// to that endpoint - I think the ResponseWriter sees the writes,
-			// but they wouldn't make it to the wire and the client wouldn't
-			// see them.
-			// TODO: same for CSRF, should be tested via a client with a
-			// CookieJar.
 			t.Log(logBuf.String())
 		})
 	}
 }
+
+func TestHandler_HEAD(t *testing.T) {
+	var logBuf bytes.Buffer
+	logrus.SetOutput(&logBuf)
+
+	const jsonDoc = `{
+  "value": "hello"
+}`
+
+	s := &kick.Server{
+		Root: &kick.Root{
+			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+				w.WriteHeader(500)
+			},
+			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+				logrus.WithFields(info).Info()
+			},
+			TrustProxyHeaders:   true,
+			AllowMethodOverride: true,
+			RequestIDHeader:     "X-Request-Id",
+		},
+		Routes: []*kick.Route{
+			{
+				Method: "GET",
+				Path:   "/",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Add("Content-Length", strconv.Itoa(len(jsonDoc)))
+					fmt.Fprint(w, jsonDoc)
+				}),
+			},
+		},
+	}
+
+	h, err := Handler(s)
+	if err != nil {
+		t.Fatalf("failed to configure handlers: %s", err)
+	}
+
+	srv := httptest.NewServer(h)
+	defer srv.Close()
+
+	cli := &http.Client{Timeout: time.Second}
+	req, err := http.NewRequest("HEAD", srv.URL+"/", nil)
+	if err != nil {
+		t.Fatalf("failed to create new request: %s", err)
+	}
+
+	res, err := cli.Do(req)
+	if err != nil {
+		t.Fatalf("request failed: %s", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 200 {
+		t.Fatalf("want status 200, got %d", res.StatusCode)
+	}
+	if int(res.ContentLength) != len(jsonDoc) {
+		t.Fatalf("want Content-Length %d, got %d", len(jsonDoc), res.ContentLength)
+	}
+
+	b, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		t.Fatalf("failed to read response body: %s", err)
+	}
+	if len(b) > 0 {
+		t.Fatalf("want empty body, got %s", string(b))
+	}
+}
+
+func TestHandler_CSRF(t *testing.T) {
+	var logBuf bytes.Buffer
+	logrus.SetOutput(&logBuf)
+
+	csrfConf := &kick.CSRFConfig{
+		AuthKey:    securecookie.GenerateRandomKey(32),
+		CookieName: "csrf",
+		HTTPOnly:   true,
+		Secure:     false,
+	}
+	s := &kick.Server{
+		Root: &kick.Root{
+			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+				w.WriteHeader(500)
+			},
+			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+				logrus.WithFields(info).Info()
+			},
+			TrustProxyHeaders:   true,
+			AllowMethodOverride: true,
+			RequestIDHeader:     "X-Request-Id",
+		},
+		Routes: []*kick.Route{
+			{
+				Method: "GET",
+				Path:   "/",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Set("X-CSRF-Token", csrf.Token(r))
+				}),
+				Config: &kick.HandlerConfig{
+					CSRF: csrfConf,
+				},
+			},
+			{
+				Method:  "POST",
+				Path:    "/",
+				Handler: statusHandler(200),
+				Config: &kick.HandlerConfig{
+					CSRF: csrfConf,
+				},
+			},
+		},
+	}
+
+	h, err := Handler(s)
+	if err != nil {
+		t.Fatalf("failed to configure handler: %s", err)
+	}
+	srv := httptest.NewServer(h)
+	defer srv.Close()
+
+	jar, err := cookiejar.New(nil)
+	if err != nil {
+		t.Fatalf("failed to create cookie jar: %s", err)
+	}
+	cli := &http.Client{
+		Jar:     jar,
+		Timeout: time.Second,
+	}
+
+	// #1: initial GET receives CSRF cookie and token
+	var tok1 string
+	t.Run("#1 initial GET", func(t *testing.T) {
+		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+		if res.StatusCode != 200 {
+			t.Fatalf("want GET #1 status 200, got %d", res.StatusCode)
+		}
+		tok1 = res.Header.Get("X-CSRF-Token")
+		if tok1 == "" {
+			t.Fatalf("want csrf token header to be set")
+		}
+	})
+
+	// #2: POST without CSRF token
+	t.Run("#2 POST without token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", "", false)
+		if res.StatusCode != 403 {
+			t.Fatalf("want POST #2 status 403, got %d", res.StatusCode)
+		}
+	})
+
+	// #3: POST with GET #1 token succeeds
+	t.Run("#3 POST with token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", tok1, false)
+		if res.StatusCode != 200 {
+			t.Fatalf("want POST #3 status 200, got %d", res.StatusCode)
+		}
+	})
+
+	// #4: GET receives new CSRF cookie and token
+	var tok4 string
+	t.Run("#4 GET new token", func(t *testing.T) {
+		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+		if res.StatusCode != 200 {
+			t.Fatalf("want GET #4 status 200, got %d", res.StatusCode)
+		}
+		tok4 = res.Header.Get("X-CSRF-Token")
+	})
+
+	// #5: POST with valid token succeeds
+	t.Run("#5 POST with token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", tok4, false)
+		if res.StatusCode != 200 {
+			t.Fatalf("want POST #5 status 200, got %d", res.StatusCode)
+		}
+	})
+}
+
+func makeCall(t *testing.T, cli *http.Client, method, path, tok string, checkCookie bool) *http.Response {
+	req, err := http.NewRequest(method, path, nil)
+	if err != nil {
+		t.Fatalf("failed to create request: %s", err)
+	}
+	if tok != "" {
+		req.Header.Set("X-CSRF-Token", tok)
+	}
+	res, err := cli.Do(req)
+	if err != nil {
+		t.Fatalf("failed call: %s", err)
+	}
+	defer res.Body.Close()
+
+	if checkCookie {
+		cookies := cli.Jar.Cookies(res.Request.URL)
+		ck := cookieByName(cookies, "csrf")
+		if ck == nil || ck.Value == "" {
+			t.Fatalf("want csrf cookie to be set")
+		}
+	}
+
+	return res
+}
+
+func cookieByName(cookies []*http.Cookie, nm string) *http.Cookie {
+	for _, c := range cookies {
+		if c.Name == nm {
+			return c
+		}
+	}
+	return nil
+}

M builder/middleware.go => builder/middleware.go +4 -1
@@ 171,7 171,7 @@ func panicRecovery(recoverFn func(http.ResponseWriter, *http.Request, interface{
 
 // logging calls the wrapped handler and collects relevant information
 // about the request, and then calls logFn with that information.
-func logging(logFn func(http.ResponseWriter, *http.Request, map[string]interface{})) func(http.Handler) http.Handler {
+func logging(reqIDHeader string, logFn func(http.ResponseWriter, *http.Request, map[string]interface{})) func(http.Handler) http.Handler {
 	return func(h http.Handler) http.Handler {
 		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 			if _, ok := w.(interface {


@@ 201,6 201,9 @@ func logging(logFn func(http.ResponseWriter, *http.Request, map[string]interface
 				"user_agent":          r.UserAgent(),
 				"remote_addr":         r.RemoteAddr,
 			}
+			if reqIDHeader != "" {
+				m["request_id"] = w.Header().Get(reqIDHeader)
+			}
 			if w, ok := w.(interface {
 				Status() int
 				Size() int64

M builder/middleware_test.go => builder/middleware_test.go +1 -1
@@ 285,7 285,7 @@ func TestLogging(t *testing.T) {
 	logFn := func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
 		logger.WithFields(logrus.Fields(info)).Info("logging")
 	}
-	h := logging(logFn)(statusHandler(204))
+	h := logging("", logFn)(statusHandler(204))
 
 	w := httptest.NewRecorder()
 	r, _ := http.NewRequest("", "/", nil)

M go.mod => go.mod +1 -0
@@ 6,6 6,7 @@ require (
 	github.com/NYTimes/gziphandler v1.1.1
 	github.com/gorilla/csrf v1.5.1
 	github.com/gorilla/handlers v1.4.1-0.20190227193432-ac6d24f88de4
+	github.com/gorilla/securecookie v1.1.1
 	github.com/juju/ratelimit v1.0.1
 	github.com/julienschmidt/httprouter v1.2.0
 	github.com/kr/pretty v0.1.0 // indirect

M kick.go => kick.go +21 -2
@@ 56,8 56,6 @@ type Server struct {
 	// Routes is the list of routes served by the server.
 	Routes []*Route
 
-	// TODO: validate that http2 is properly supported.
-
 	// TLS configures the TLS settings of the server.
 	TLS *TLSConfig
 


@@ 119,6 117,10 @@ func (s *Server) build() error {
 		return err
 	}
 
+	if s.TLS != nil && s.TLS.DisableHTTP2 && srv.TLSNextProto == nil {
+		srv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
+	}
+
 	h, err := s.Builders.Handler(s)
 	if err != nil {
 		return err


@@ 342,6 344,16 @@ type TLSConfig struct {
 	CertFile string
 	KeyFile  string
 
+	// DisableHTTP2 indicates that the TLS server should not support
+	// HTTP/2. By default, a Server with a TLSConfig would automatically
+	// support HTTP/2, but it may be useful in certain conditions to
+	// disable that automatic support (e.g. for Websockets, at least
+	// until those are supported on HTTP/2).
+	//
+	// This is an option on TLSConfig and not directly on the Server,
+	// because automatic HTTP/2 support is only enabled for TLS servers.
+	DisableHTTP2 bool
+
 	_ struct{} // prevent unkeyed struct creation
 }
 


@@ 383,6 395,13 @@ type Root struct {
 	// for the requested method and path.
 	NotFoundHandler http.Handler
 
+	// RouterMiddleware is a middleware function that is applied to
+	// the router itself, meaning that it's handler will run for all
+	// routes, including the NotFound and MethodNotAllowed handlers,
+	// but ensuring it runs after the root middleware so that e.g.
+	// method overrides, request IDs, etc. are applied first.
+	RouterMiddleware func(http.Handler) http.Handler
+
 	TrustProxyHeaders   bool
 	AllowMethodOverride bool
 

M kick_test.go => kick_test.go +106 -0
@@ 39,6 39,111 @@ func nextPort() int {
 	return nonRandomPort
 }
 
+func TestServer_HTTP2(t *testing.T) {
+	const (
+		timeout      = time.Second
+		requestAfter = 500 * time.Millisecond
+	)
+
+	var port = nextPort()
+
+	s := kick.Server{
+		Builders: builder.Default,
+		Addr:     fmt.Sprintf(":%d", port),
+		Root: &kick.Root{
+			NotFoundHandler: http.NotFoundHandler(),
+		},
+		TLS: &kick.TLSConfig{
+			CertFile: localhostCert,
+			KeyFile:  localhostKey,
+		},
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		err := s.ListenAndServe(ctx)
+		if err != nil {
+			t.Fatalf("want no server error, got %s", err)
+		}
+	}()
+
+	time.Sleep(requestAfter)
+	res, err := http.Get(fmt.Sprintf("https://localhost:%d/", port))
+	if err != nil {
+		t.Fatalf("want no client error, got %s", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 404 {
+		t.Fatalf("want status code 404, got %d", res.StatusCode)
+	}
+	if res.ProtoMajor < 2 {
+		t.Fatalf("want http/2, got %s", res.Proto)
+	}
+
+	cancel()
+	wg.Wait()
+}
+
+func TestServer_HTTP2Disabled(t *testing.T) {
+	const (
+		timeout      = time.Second
+		requestAfter = 500 * time.Millisecond
+	)
+
+	var port = nextPort()
+
+	s := kick.Server{
+		Builders: builder.Default,
+		Addr:     fmt.Sprintf(":%d", port),
+		Root: &kick.Root{
+			NotFoundHandler: http.NotFoundHandler(),
+		},
+		TLS: &kick.TLSConfig{
+			CertFile:     localhostCert,
+			KeyFile:      localhostKey,
+			DisableHTTP2: true,
+		},
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+
+		err := s.ListenAndServe(ctx)
+		if err != nil {
+			t.Fatalf("want no server error, got %s", err)
+		}
+	}()
+
+	time.Sleep(requestAfter)
+	res, err := http.Get(fmt.Sprintf("https://localhost:%d/", port))
+	if err != nil {
+		t.Fatalf("want no client error, got %s", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 404 {
+		t.Fatalf("want status code 404, got %d", res.StatusCode)
+	}
+	if res.ProtoMajor >= 2 {
+		t.Fatalf("want http/1.1, got %s", res.Proto)
+	}
+
+	cancel()
+	wg.Wait()
+}
+
 func TestServer_TLS(t *testing.T) {
 	const (
 		timeout      = time.Second


@@ 99,6 204,7 @@ func TestServer_TLS(t *testing.T) {
 			if string(b) != "ok" {
 				t.Fatalf(`want response "ok", got %s`, string(b))
 			}
+			cancel()
 			wg.Wait()
 		})
 	}