937534285786521c2dce57849569a59c1108e569 — Martin Angers 5 months ago 9fdaae1 + 275531d
Merge branch 'master' of git.sr.ht:~mna/kick
M TODO.md => TODO.md +1 -0
@@ 13,3 13,4 @@ 11. Structured logging
  12. Access path / router variables from handler
  13. Disable http/2 option (maybe to support websocket)
+ 14. Health checks, including DBs

M builder/builder.go => builder/builder.go +6 -2
@@ 156,7 156,11 @@ // routes, respectively).
  	configureGetAndOptions(router, s.Routes)
  
- 	h := combineMiddleware(router, mw)
+ 	routerHandler := http.Handler(router)
+ 	if s.Root != nil && s.Root.RouterMiddleware != nil {
+ 		routerHandler = s.Root.RouterMiddleware(router)
+ 	}
+ 	h := combineMiddleware(routerHandler, mw)
  	return h, nil
  }
  


@@ 172,7 176,7 @@ var mw []middleware
  
  	if fn := s.Root.LoggingFunc; fn != nil {
- 		mw = append(mw, logging(fn))
+ 		mw = append(mw, logging(s.Root.RequestIDHeader, fn))
  	}
  	if fn := s.Root.PanicRecoveryFunc; fn != nil {
  		mw = append(mw, panicRecovery(fn))

M builder/builder_test.go => builder/builder_test.go +213 -8
@@ 10,6 10,7 @@ "io"
  	"io/ioutil"
  	"net/http"
+ 	"net/http/cookiejar"
  	"net/http/httptest"
  	"os"
  	"strconv"


@@ 18,6 19,8 @@ "time"
  
  	"git.sr.ht/~mna/kick"
+ 	"github.com/gorilla/csrf"
+ 	"github.com/gorilla/securecookie"
  	"github.com/sirupsen/logrus"
  	"github.com/unrolled/secure"
  )


@@ 293,8 296,6 @@ "value": "hello"
  }`
  
- 	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))
- 
  	noslashCORS := &kick.CORSConfig{
  		AllowedMethods: []string{"GET", "PUT"},
  		AllowedOrigins: []string{"http://notexample.com", "notexample.com"},


@@ 691,13 692,217 @@ }
  			}
  
- 			// TODO: HEAD body should be verified by an actual client call
- 			// to that endpoint - I think the ResponseWriter sees the writes,
- 			// but they wouldn't make it to the wire and the client wouldn't
- 			// see them.
- 			// TODO: same for CSRF, should be tested via a client with a
- 			// CookieJar.
  			t.Log(logBuf.String())
  		})
  	}
  }
+ 
+ func TestHandler_HEAD(t *testing.T) {
+ 	var logBuf bytes.Buffer
+ 	logrus.SetOutput(&logBuf)
+ 
+ 	const jsonDoc = `{
+   "value": "hello"
+ }`
+ 
+ 	s := &kick.Server{
+ 		Root: &kick.Root{
+ 			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+ 				w.WriteHeader(500)
+ 			},
+ 			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+ 				logrus.WithFields(info).Info()
+ 			},
+ 			TrustProxyHeaders:   true,
+ 			AllowMethodOverride: true,
+ 			RequestIDHeader:     "X-Request-Id",
+ 		},
+ 		Routes: []*kick.Route{
+ 			{
+ 				Method: "GET",
+ 				Path:   "/",
+ 				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 					w.Header().Add("Content-Length", strconv.Itoa(len(jsonDoc)))
+ 					fmt.Fprint(w, jsonDoc)
+ 				}),
+ 			},
+ 		},
+ 	}
+ 
+ 	h, err := Handler(s)
+ 	if err != nil {
+ 		t.Fatalf("failed to configure handlers: %s", err)
+ 	}
+ 
+ 	srv := httptest.NewServer(h)
+ 	defer srv.Close()
+ 
+ 	cli := &http.Client{Timeout: time.Second}
+ 	req, err := http.NewRequest("HEAD", srv.URL+"/", nil)
+ 	if err != nil {
+ 		t.Fatalf("failed to create new request: %s", err)
+ 	}
+ 
+ 	res, err := cli.Do(req)
+ 	if err != nil {
+ 		t.Fatalf("request failed: %s", err)
+ 	}
+ 	defer res.Body.Close()
+ 
+ 	if res.StatusCode != 200 {
+ 		t.Fatalf("want status 200, got %d", res.StatusCode)
+ 	}
+ 	if int(res.ContentLength) != len(jsonDoc) {
+ 		t.Fatalf("want Content-Length %d, got %d", len(jsonDoc), res.ContentLength)
+ 	}
+ 
+ 	b, err := ioutil.ReadAll(res.Body)
+ 	if err != nil {
+ 		t.Fatalf("failed to read response body: %s", err)
+ 	}
+ 	if len(b) > 0 {
+ 		t.Fatalf("want empty body, got %s", string(b))
+ 	}
+ }
+ 
+ func TestHandler_CSRF(t *testing.T) {
+ 	var logBuf bytes.Buffer
+ 	logrus.SetOutput(&logBuf)
+ 
+ 	csrfConf := &kick.CSRFConfig{
+ 		AuthKey:    securecookie.GenerateRandomKey(32),
+ 		CookieName: "csrf",
+ 		HTTPOnly:   true,
+ 		Secure:     false,
+ 	}
+ 	s := &kick.Server{
+ 		Root: &kick.Root{
+ 			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+ 				w.WriteHeader(500)
+ 			},
+ 			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+ 				logrus.WithFields(info).Info()
+ 			},
+ 			TrustProxyHeaders:   true,
+ 			AllowMethodOverride: true,
+ 			RequestIDHeader:     "X-Request-Id",
+ 		},
+ 		Routes: []*kick.Route{
+ 			{
+ 				Method: "GET",
+ 				Path:   "/",
+ 				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ 					w.Header().Set("X-CSRF-Token", csrf.Token(r))
+ 				}),
+ 				Config: &kick.HandlerConfig{
+ 					CSRF: csrfConf,
+ 				},
+ 			},
+ 			{
+ 				Method:  "POST",
+ 				Path:    "/",
+ 				Handler: statusHandler(200),
+ 				Config: &kick.HandlerConfig{
+ 					CSRF: csrfConf,
+ 				},
+ 			},
+ 		},
+ 	}
+ 
+ 	h, err := Handler(s)
+ 	if err != nil {
+ 		t.Fatalf("failed to configure handler: %s", err)
+ 	}
+ 	srv := httptest.NewServer(h)
+ 	defer srv.Close()
+ 
+ 	jar, err := cookiejar.New(nil)
+ 	if err != nil {
+ 		t.Fatalf("failed to create cookie jar: %s", err)
+ 	}
+ 	cli := &http.Client{
+ 		Jar:     jar,
+ 		Timeout: time.Second,
+ 	}
+ 
+ 	// #1: initial GET receives CSRF cookie and token
+ 	var tok1 string
+ 	t.Run("#1 initial GET", func(t *testing.T) {
+ 		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+ 		if res.StatusCode != 200 {
+ 			t.Fatalf("want GET #1 status 200, got %d", res.StatusCode)
+ 		}
+ 		tok1 = res.Header.Get("X-CSRF-Token")
+ 		if tok1 == "" {
+ 			t.Fatalf("want csrf token header to be set")
+ 		}
+ 	})
+ 
+ 	// #2: POST without CSRF token
+ 	t.Run("#2 POST without token", func(t *testing.T) {
+ 		res := makeCall(t, cli, "POST", srv.URL+"/", "", false)
+ 		if res.StatusCode != 403 {
+ 			t.Fatalf("want POST #2 status 403, got %d", res.StatusCode)
+ 		}
+ 	})
+ 
+ 	// #3: POST with GET #1 token succeeds
+ 	t.Run("#3 POST with token", func(t *testing.T) {
+ 		res := makeCall(t, cli, "POST", srv.URL+"/", tok1, false)
+ 		if res.StatusCode != 200 {
+ 			t.Fatalf("want POST #3 status 200, got %d", res.StatusCode)
+ 		}
+ 	})
+ 
+ 	// #4: GET receives new CSRF cookie and token
+ 	var tok4 string
+ 	t.Run("#4 GET new token", func(t *testing.T) {
+ 		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+ 		if res.StatusCode != 200 {
+ 			t.Fatalf("want GET #4 status 200, got %d", res.StatusCode)
+ 		}
+ 		tok4 = res.Header.Get("X-CSRF-Token")
+ 	})
+ 
+ 	// #5: POST with valid token succeeds
+ 	t.Run("#5 POST with token", func(t *testing.T) {
+ 		res := makeCall(t, cli, "POST", srv.URL+"/", tok4, false)
+ 		if res.StatusCode != 200 {
+ 			t.Fatalf("want POST #5 status 200, got %d", res.StatusCode)
+ 		}
+ 	})
+ }
+ 
+ func makeCall(t *testing.T, cli *http.Client, method, path, tok string, checkCookie bool) *http.Response {
+ 	req, err := http.NewRequest(method, path, nil)
+ 	if err != nil {
+ 		t.Fatalf("failed to create request: %s", err)
+ 	}
+ 	if tok != "" {
+ 		req.Header.Set("X-CSRF-Token", tok)
+ 	}
+ 	res, err := cli.Do(req)
+ 	if err != nil {
+ 		t.Fatalf("failed call: %s", err)
+ 	}
+ 	defer res.Body.Close()
+ 
+ 	if checkCookie {
+ 		cookies := cli.Jar.Cookies(res.Request.URL)
+ 		ck := cookieByName(cookies, "csrf")
+ 		if ck == nil || ck.Value == "" {
+ 			t.Fatalf("want csrf cookie to be set")
+ 		}
+ 	}
+ 
+ 	return res
+ }
+ 
+ func cookieByName(cookies []*http.Cookie, nm string) *http.Cookie {
+ 	for _, c := range cookies {
+ 		if c.Name == nm {
+ 			return c
+ 		}
+ 	}
+ 	return nil
+ }

M builder/middleware.go => builder/middleware.go +4 -1
@@ 171,7 171,7 @@   // logging calls the wrapped handler and collects relevant information
  // about the request, and then calls logFn with that information.
- func logging(logFn func(http.ResponseWriter, *http.Request, map[string]interface{})) func(http.Handler) http.Handler {
+ func logging(reqIDHeader string, logFn func(http.ResponseWriter, *http.Request, map[string]interface{})) func(http.Handler) http.Handler {
  	return func(h http.Handler) http.Handler {
  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  			if _, ok := w.(interface {


@@ 201,6 201,9 @@ "user_agent":          r.UserAgent(),
  				"remote_addr":         r.RemoteAddr,
  			}
+ 			if reqIDHeader != "" {
+ 				m["request_id"] = w.Header().Get(reqIDHeader)
+ 			}
  			if w, ok := w.(interface {
  				Status() int
  				Size() int64

M builder/middleware_test.go => builder/middleware_test.go +1 -1
@@ 285,7 285,7 @@ logFn := func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
  		logger.WithFields(logrus.Fields(info)).Info("logging")
  	}
- 	h := logging(logFn)(statusHandler(204))
+ 	h := logging("", logFn)(statusHandler(204))
  
  	w := httptest.NewRecorder()
  	r, _ := http.NewRequest("", "/", nil)

M go.mod => go.mod +1 -0
@@ 6,6 6,7 @@ github.com/NYTimes/gziphandler v1.1.1
  	github.com/gorilla/csrf v1.5.1
  	github.com/gorilla/handlers v1.4.1-0.20190227193432-ac6d24f88de4
+ 	github.com/gorilla/securecookie v1.1.1
  	github.com/juju/ratelimit v1.0.1
  	github.com/julienschmidt/httprouter v1.2.0
  	github.com/kr/pretty v0.1.0 // indirect

M kick.go => kick.go +21 -2
@@ 56,8 56,6 @@ // Routes is the list of routes served by the server.
  	Routes []*Route
  
- 	// TODO: validate that http2 is properly supported.
- 
  	// TLS configures the TLS settings of the server.
  	TLS *TLSConfig
  


@@ 119,6 117,10 @@ return err
  	}
  
+ 	if s.TLS != nil && s.TLS.DisableHTTP2 && srv.TLSNextProto == nil {
+ 		srv.TLSNextProto = make(map[string]func(*http.Server, *tls.Conn, http.Handler))
+ 	}
+ 
  	h, err := s.Builders.Handler(s)
  	if err != nil {
  		return err


@@ 342,6 344,16 @@ CertFile string
  	KeyFile  string
  
+ 	// DisableHTTP2 indicates that the TLS server should not support
+ 	// HTTP/2. By default, a Server with a TLSConfig would automatically
+ 	// support HTTP/2, but it may be useful in certain conditions to
+ 	// disable that automatic support (e.g. for Websockets, at least
+ 	// until those are supported on HTTP/2).
+ 	//
+ 	// This is an option on TLSConfig and not directly on the Server,
+ 	// because automatic HTTP/2 support is only enabled for TLS servers.
+ 	DisableHTTP2 bool
+ 
  	_ struct{} // prevent unkeyed struct creation
  }
  


@@ 383,6 395,13 @@ // for the requested method and path.
  	NotFoundHandler http.Handler
  
+ 	// RouterMiddleware is a middleware function that is applied to
+ 	// the router itself, meaning that it's handler will run for all
+ 	// routes, including the NotFound and MethodNotAllowed handlers,
+ 	// but ensuring it runs after the root middleware so that e.g.
+ 	// method overrides, request IDs, etc. are applied first.
+ 	RouterMiddleware func(http.Handler) http.Handler
+ 
  	TrustProxyHeaders   bool
  	AllowMethodOverride bool
  

M kick_test.go => kick_test.go +106 -0
@@ 39,6 39,111 @@ return nonRandomPort
  }
  
+ func TestServer_HTTP2(t *testing.T) {
+ 	const (
+ 		timeout      = time.Second
+ 		requestAfter = 500 * time.Millisecond
+ 	)
+ 
+ 	var port = nextPort()
+ 
+ 	s := kick.Server{
+ 		Builders: builder.Default,
+ 		Addr:     fmt.Sprintf(":%d", port),
+ 		Root: &kick.Root{
+ 			NotFoundHandler: http.NotFoundHandler(),
+ 		},
+ 		TLS: &kick.TLSConfig{
+ 			CertFile: localhostCert,
+ 			KeyFile:  localhostKey,
+ 		},
+ 	}
+ 
+ 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ 	defer cancel()
+ 
+ 	var wg sync.WaitGroup
+ 	wg.Add(1)
+ 	go func() {
+ 		defer wg.Done()
+ 
+ 		err := s.ListenAndServe(ctx)
+ 		if err != nil {
+ 			t.Fatalf("want no server error, got %s", err)
+ 		}
+ 	}()
+ 
+ 	time.Sleep(requestAfter)
+ 	res, err := http.Get(fmt.Sprintf("https://localhost:%d/", port))
+ 	if err != nil {
+ 		t.Fatalf("want no client error, got %s", err)
+ 	}
+ 	defer res.Body.Close()
+ 
+ 	if res.StatusCode != 404 {
+ 		t.Fatalf("want status code 404, got %d", res.StatusCode)
+ 	}
+ 	if res.ProtoMajor < 2 {
+ 		t.Fatalf("want http/2, got %s", res.Proto)
+ 	}
+ 
+ 	cancel()
+ 	wg.Wait()
+ }
+ 
+ func TestServer_HTTP2Disabled(t *testing.T) {
+ 	const (
+ 		timeout      = time.Second
+ 		requestAfter = 500 * time.Millisecond
+ 	)
+ 
+ 	var port = nextPort()
+ 
+ 	s := kick.Server{
+ 		Builders: builder.Default,
+ 		Addr:     fmt.Sprintf(":%d", port),
+ 		Root: &kick.Root{
+ 			NotFoundHandler: http.NotFoundHandler(),
+ 		},
+ 		TLS: &kick.TLSConfig{
+ 			CertFile:     localhostCert,
+ 			KeyFile:      localhostKey,
+ 			DisableHTTP2: true,
+ 		},
+ 	}
+ 
+ 	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+ 	defer cancel()
+ 
+ 	var wg sync.WaitGroup
+ 	wg.Add(1)
+ 	go func() {
+ 		defer wg.Done()
+ 
+ 		err := s.ListenAndServe(ctx)
+ 		if err != nil {
+ 			t.Fatalf("want no server error, got %s", err)
+ 		}
+ 	}()
+ 
+ 	time.Sleep(requestAfter)
+ 	res, err := http.Get(fmt.Sprintf("https://localhost:%d/", port))
+ 	if err != nil {
+ 		t.Fatalf("want no client error, got %s", err)
+ 	}
+ 	defer res.Body.Close()
+ 
+ 	if res.StatusCode != 404 {
+ 		t.Fatalf("want status code 404, got %d", res.StatusCode)
+ 	}
+ 	if res.ProtoMajor >= 2 {
+ 		t.Fatalf("want http/1.1, got %s", res.Proto)
+ 	}
+ 
+ 	cancel()
+ 	wg.Wait()
+ }
+ 
  func TestServer_TLS(t *testing.T) {
  	const (
  		timeout      = time.Second


@@ 99,6 204,7 @@ if string(b) != "ok" {
  				t.Fatalf(`want response "ok", got %s`, string(b))
  			}
+ 			cancel()
  			wg.Wait()
  		})
  	}