616fcfa757ac32d7a6a7d9834299f20a602ac0a3 — Martin Angers 7 months ago 41e6ab7
test csrf
2 files changed, 214 insertions(+), 8 deletions(-)

M builder/builder_test.go
M go.mod
M builder/builder_test.go => builder/builder_test.go +213 -8
@@ 10,6 10,7 @@ import (
 	"io"
 	"io/ioutil"
 	"net/http"
+	"net/http/cookiejar"
 	"net/http/httptest"
 	"os"
 	"strconv"


@@ 18,6 19,8 @@ import (
 	"time"
 
 	"git.sr.ht/~mna/kick"
+	"github.com/gorilla/csrf"
+	"github.com/gorilla/securecookie"
 	"github.com/sirupsen/logrus"
 	"github.com/unrolled/secure"
 )


@@ 293,8 296,6 @@ func TestHandler_Valid(t *testing.T) {
   "value": "hello"
 }`
 
-	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))
-
 	noslashCORS := &kick.CORSConfig{
 		AllowedMethods: []string{"GET", "PUT"},
 		AllowedOrigins: []string{"http://notexample.com", "notexample.com"},


@@ 691,13 692,217 @@ func TestHandler_Valid(t *testing.T) {
 				}
 			}
 
-			// TODO: HEAD body should be verified by an actual client call
-			// to that endpoint - I think the ResponseWriter sees the writes,
-			// but they wouldn't make it to the wire and the client wouldn't
-			// see them.
-			// TODO: same for CSRF, should be tested via a client with a
-			// CookieJar.
 			t.Log(logBuf.String())
 		})
 	}
 }
+
+func TestHandler_HEAD(t *testing.T) {
+	var logBuf bytes.Buffer
+	logrus.SetOutput(&logBuf)
+
+	const jsonDoc = `{
+  "value": "hello"
+}`
+
+	s := &kick.Server{
+		Root: &kick.Root{
+			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+				w.WriteHeader(500)
+			},
+			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+				logrus.WithFields(info).Info()
+			},
+			TrustProxyHeaders:   true,
+			AllowMethodOverride: true,
+			RequestIDHeader:     "X-Request-Id",
+		},
+		Routes: []*kick.Route{
+			{
+				Method: "GET",
+				Path:   "/",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Add("Content-Length", strconv.Itoa(len(jsonDoc)))
+					fmt.Fprint(w, jsonDoc)
+				}),
+			},
+		},
+	}
+
+	h, err := Handler(s)
+	if err != nil {
+		t.Fatalf("failed to configure handlers: %s", err)
+	}
+
+	srv := httptest.NewServer(h)
+	defer srv.Close()
+
+	cli := &http.Client{Timeout: time.Second}
+	req, err := http.NewRequest("HEAD", srv.URL+"/", nil)
+	if err != nil {
+		t.Fatalf("failed to create new request: %s", err)
+	}
+
+	res, err := cli.Do(req)
+	if err != nil {
+		t.Fatalf("request failed: %s", err)
+	}
+	defer res.Body.Close()
+
+	if res.StatusCode != 200 {
+		t.Fatalf("want status 200, got %d", res.StatusCode)
+	}
+	if int(res.ContentLength) != len(jsonDoc) {
+		t.Fatalf("want Content-Length %d, got %d", len(jsonDoc), res.ContentLength)
+	}
+
+	b, err := ioutil.ReadAll(res.Body)
+	if err != nil {
+		t.Fatalf("failed to read response body: %s", err)
+	}
+	if len(b) > 0 {
+		t.Fatalf("want empty body, got %s", string(b))
+	}
+}
+
+func TestHandler_CSRF(t *testing.T) {
+	var logBuf bytes.Buffer
+	logrus.SetOutput(&logBuf)
+
+	csrfConf := &kick.CSRFConfig{
+		AuthKey:    securecookie.GenerateRandomKey(32),
+		CookieName: "csrf",
+		HTTPOnly:   true,
+		Secure:     false,
+	}
+	s := &kick.Server{
+		Root: &kick.Root{
+			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
+				w.WriteHeader(500)
+			},
+			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+				logrus.WithFields(info).Info()
+			},
+			TrustProxyHeaders:   true,
+			AllowMethodOverride: true,
+			RequestIDHeader:     "X-Request-Id",
+		},
+		Routes: []*kick.Route{
+			{
+				Method: "GET",
+				Path:   "/",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					w.Header().Set("X-CSRF-Token", csrf.Token(r))
+				}),
+				Config: &kick.HandlerConfig{
+					CSRF: csrfConf,
+				},
+			},
+			{
+				Method:  "POST",
+				Path:    "/",
+				Handler: statusHandler(200),
+				Config: &kick.HandlerConfig{
+					CSRF: csrfConf,
+				},
+			},
+		},
+	}
+
+	h, err := Handler(s)
+	if err != nil {
+		t.Fatalf("failed to configure handler: %s", err)
+	}
+	srv := httptest.NewServer(h)
+	defer srv.Close()
+
+	jar, err := cookiejar.New(nil)
+	if err != nil {
+		t.Fatalf("failed to create cookie jar: %s", err)
+	}
+	cli := &http.Client{
+		Jar:     jar,
+		Timeout: time.Second,
+	}
+
+	// #1: initial GET receives CSRF cookie and token
+	var tok1 string
+	t.Run("#1 initial GET", func(t *testing.T) {
+		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+		if res.StatusCode != 200 {
+			t.Fatalf("want GET #1 status 200, got %d", res.StatusCode)
+		}
+		tok1 = res.Header.Get("X-CSRF-Token")
+		if tok1 == "" {
+			t.Fatalf("want csrf token header to be set")
+		}
+	})
+
+	// #2: POST without CSRF token
+	t.Run("#2 POST without token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", "", false)
+		if res.StatusCode != 403 {
+			t.Fatalf("want POST #2 status 403, got %d", res.StatusCode)
+		}
+	})
+
+	// #3: POST with GET #1 token succeeds
+	t.Run("#3 POST with token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", tok1, false)
+		if res.StatusCode != 200 {
+			t.Fatalf("want POST #3 status 200, got %d", res.StatusCode)
+		}
+	})
+
+	// #4: GET receives new CSRF cookie and token
+	var tok4 string
+	t.Run("#4 GET new token", func(t *testing.T) {
+		res := makeCall(t, cli, "GET", srv.URL+"/", "", true)
+		if res.StatusCode != 200 {
+			t.Fatalf("want GET #4 status 200, got %d", res.StatusCode)
+		}
+		tok4 = res.Header.Get("X-CSRF-Token")
+	})
+
+	// #5: POST with valid token succeeds
+	t.Run("#5 POST with token", func(t *testing.T) {
+		res := makeCall(t, cli, "POST", srv.URL+"/", tok4, false)
+		if res.StatusCode != 200 {
+			t.Fatalf("want POST #5 status 200, got %d", res.StatusCode)
+		}
+	})
+}
+
+func makeCall(t *testing.T, cli *http.Client, method, path, tok string, checkCookie bool) *http.Response {
+	req, err := http.NewRequest(method, path, nil)
+	if err != nil {
+		t.Fatalf("failed to create request: %s", err)
+	}
+	if tok != "" {
+		req.Header.Set("X-CSRF-Token", tok)
+	}
+	res, err := cli.Do(req)
+	if err != nil {
+		t.Fatalf("failed call: %s", err)
+	}
+	defer res.Body.Close()
+
+	if checkCookie {
+		cookies := cli.Jar.Cookies(res.Request.URL)
+		ck := cookieByName(cookies, "csrf")
+		if ck == nil || ck.Value == "" {
+			t.Fatalf("want csrf cookie to be set")
+		}
+	}
+
+	return res
+}
+
+func cookieByName(cookies []*http.Cookie, nm string) *http.Cookie {
+	for _, c := range cookies {
+		if c.Name == nm {
+			return c
+		}
+	}
+	return nil
+}

M go.mod => go.mod +1 -0
@@ 6,6 6,7 @@ require (
 	github.com/NYTimes/gziphandler v1.1.1
 	github.com/gorilla/csrf v1.5.1
 	github.com/gorilla/handlers v1.4.1-0.20190227193432-ac6d24f88de4
+	github.com/gorilla/securecookie v1.1.1
 	github.com/juju/ratelimit v1.0.1
 	github.com/julienschmidt/httprouter v1.2.0
 	github.com/kr/pretty v0.1.0 // indirect