164e80fe576000d090daf286ba66c4d096f7a898 — Martin Angers 7 months ago e42652b
test more middleware configuration cases
3 files changed, 315 insertions(+), 8 deletions(-)

M TODO.md
M builder/builder.go
M builder/builder_test.go
M TODO.md => TODO.md +3 -2
@@ 9,6 9,7 @@
 7. 3rd party services/api calls with circuit breaker (also microservices)
 8. Batch, async & one-off tools
 9. Redirect http to https
-10. Reverse proxy in front
+10. Reverse proxy in front (e.g. distinct web + api)
 11. Structured logging
-
+12. Access path / router variables from handler
+13. Disable http/2 option (maybe to support websocket)

M builder/builder.go => builder/builder.go +3 -3
@@ 163,6 163,9 @@
 
 	var mw []middleware
 
+	if fn := s.Root.LoggingFunc; fn != nil {
+		mw = append(mw, logging(fn))
+	}
 	if fn := s.Root.PanicRecoveryFunc; fn != nil {
 		mw = append(mw, panicRecovery(fn))
 	}


@@ 179,9 182,6 @@
 		sec := secure.New(*s.Root.SecurityHeaders)
 		mw = append(mw, sec.Handler)
 	}
-	if fn := s.Root.LoggingFunc; fn != nil {
-		mw = append(mw, logging(fn))
-	}
 
 	if s.Root.CanonicalHost != "" {
 		redir := s.Root.CanonicalHostRedirectStatusCode

M builder/builder_test.go => builder/builder_test.go +309 -3
@@ 1,16 1,23 @@
 package builder
 
 import (
+	"bytes"
+	"crypto/rand"
 	"crypto/tls"
+	"encoding/json"
+	"encoding/xml"
 	"fmt"
 	"io"
+	"io/ioutil"
 	"net/http"
 	"net/http/httptest"
 	"os"
 	"strings"
 	"testing"
+	"time"
 
 	"git.sr.ht/~mna/kick"
+	"github.com/sirupsen/logrus"
 	"github.com/unrolled/secure"
 )
 


@@ 269,6 276,24 @@
 }
 
 func TestHandler_Valid(t *testing.T) {
+	var logBuf bytes.Buffer
+	logrus.SetOutput(&logBuf)
+
+	const (
+		maxResponseBody = 10
+		maxRequestBody  = 20
+	)
+
+	const xmlDoc = xml.Header + `<root>
+  <value>hello</value>
+</root>`
+
+	const jsonDoc = `{
+  "value": "hello"
+}`
+
+	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))
+
 	s := &kick.Server{
 		Root: &kick.Root{
 			MethodNotAllowedHandler: statusHandler(405),


@@ 276,12 301,17 @@
 			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
 				w.WriteHeader(500)
 			},
+			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
+				logrus.WithFields(info).Info()
+			},
 			TrustProxyHeaders:   true,
 			AllowMethodOverride: true,
 			RequestIDHeader:     "X-Request-Id",
 			SecurityHeaders: &secure.Options{
 				FrameDeny: true,
 			},
+			CanonicalHostRedirectStatusCode: 301,
+			CanonicalHost:                   "http://example.com",
 			Gzip: &kick.GzipConfig{
 				ContentTypes: []string{"application/xml"},
 			},


@@ 293,12 323,106 @@
 				Handler: statusHandler(200),
 			},
 			{
+				Method:  "GET",
+				Path:    "/noslash",
+				Handler: statusHandler(200),
+			},
+			{
+				Method:  "PUT",
+				Path:    "/noslash",
+				Handler: statusHandler(204),
+			},
+			{
+				Method:  "GET",
+				Path:    "/slash/",
+				Handler: statusHandler(200),
+			},
+			{
+				Method: "GET",
+				Path:   "/sleep",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					dur, err := time.ParseDuration(r.URL.Query().Get("duration"))
+					if err != nil {
+						panic(err)
+					}
+					time.Sleep(dur)
+					w.WriteHeader(200)
+				}),
+				Config: &kick.HandlerConfig{
+					HandlerTimeout: 100 * time.Millisecond,
+				},
+			},
+			{
+				Method: "GET",
+				Path:   "/body",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					ct := w.Header().Get("Content-Type")
+					switch ct {
+					case "application/xml":
+						fmt.Fprint(w, xmlDoc)
+					case "application/json":
+						fmt.Fprint(w, jsonDoc)
+					default:
+						w.WriteHeader(400)
+					}
+				}),
+				Config: &kick.HandlerConfig{
+					ResponseContentTypes: []string{"application/json", "application/xml"},
+				},
+			},
+			{
+				Method: "GET",
+				Path:   "/validate",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					q := r.URL.Query()
+					if scheme := q.Get("scheme"); scheme != "" {
+						if scheme != r.URL.Scheme {
+							w.WriteHeader(400)
+							return
+						}
+					}
+					if ra := q.Get("remote"); ra != "" {
+						if ra != r.RemoteAddr {
+							w.WriteHeader(400)
+							return
+						}
+					}
+					w.WriteHeader(200)
+				}),
+			},
+			{
 				Method: "POST",
 				Path:   "/panic",
 				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 					panic(io.EOF)
 				}),
 			},
+			{
+				Method: "POST",
+				Path:   "/random",
+				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+					b, err := ioutil.ReadAll(r.Body)
+					if err != nil {
+						if strings.Contains(err.Error(), "too large") {
+							w.WriteHeader(413)
+						} else {
+							w.WriteHeader(500)
+						}
+						return
+					}
+					var m map[string]int64
+					if err := json.Unmarshal(b, &m); err != nil {
+						w.WriteHeader(400)
+						return
+					}
+					io.CopyN(w, rand.Reader, m["size"])
+				}),
+				Config: &kick.HandlerConfig{
+					RequestContentTypes:  []string{"application/json"},
+					MaxResponseBodyBytes: maxResponseBody,
+					MaxRequestBodyBytes:  maxRequestBody,
+				},
+			},
 		},
 	}
 


@@ 307,21 431,203 @@
 		t.Fatalf("want no error, got %s", err)
 	}
 
+	newReq := func(s string) *http.Request {
+		const partCount = 3
+
+		parts := strings.SplitN(s, " ", partCount)
+		for len(parts) < partCount {
+			parts = append(parts, "")
+		}
+
+		req := httptest.NewRequest(parts[0], parts[1], strings.NewReader(parts[2]))
+		return req
+	}
+
+	headers := func(s string) http.Header {
+		if s == "" {
+			return http.Header{}
+		}
+
+		kvs := strings.Split(s, " ")
+
+		h := make(http.Header, len(kvs))
+		for _, kv := range kvs {
+			ix := strings.Index(kv, ":")
+			if ix < 0 {
+				t.Fatalf("header key-value must contain ':': %s", kv)
+			}
+			h.Add(kv[:ix], kv[ix+1:])
+		}
+		return h
+	}
+
 	cases := []struct {
 		req  string // formatted as "METHOD /path/... body"
 		reqh string // formatted as "Name:Value Name:Value"
 		code int
-		resh string // formatted as "Name:Value", * means any value
+		resh string // formatted as "Name:Value", * means any value, - means not set
 	}{
 		{
 			req:  "GET /",
 			code: 200,
-			resh: "X-Request-Id:*",
+			resh: "X-Request-Id:* X-Frame-Options:DENY",
+		},
+		{
+			req:  "GET /",
+			reqh: "X-Request-Id:abcd",
+			code: 200,
+			resh: "X-Request-Id:abcd",
+		},
+		{
+			req:  "GET /noslash/",
+			code: 301,
+			resh: "Location:/noslash",
+		},
+		{
+			req:  "GET /slash",
+			code: 301,
+			resh: "Location:/slash/",
+		},
+		{
+			req:  "GET /noslash/.//",
+			code: 301,
+			resh: "Location:/noslash",
+		},
+		{
+			req:  "GET /panic",
+			code: 405,
+			resh: "Allow:POST,OPTIONS",
+		},
+		{
+			req:  "GET /notexist",
+			code: 404,
+		},
+		{
+			req:  "POST /panic",
+			code: 500,
+		},
+		{
+			req:  "GET http://example.com/validate?scheme=https",
+			code: 400,
+		},
+		{
+			req:  "GET http://example.com/validate?scheme=https",
+			reqh: "X-Forwarded-Proto:https",
+			code: 200,
+		},
+		{
+			req:  "GET /validate?remote=1.2.3.4",
+			code: 400,
+		},
+		{
+			req:  "GET /validate?remote=1.2.3.4",
+			reqh: "X-Forwarded-For:1.2.3.4",
+			code: 200,
+		},
+		{
+			req:  "POST /noslash",
+			code: 405,
+		},
+		{
+			req:  "POST /noslash",
+			reqh: "X-HTTP-Method-Override:PUT",
+			code: 204,
+		},
+		{
+			req:  "GET http://www.example.com/",
+			code: 301,
+			resh: "Location:http://example.com/",
+		},
+		{
+			req:  "GET /body",
+			reqh: "Accept:text/plain",
+			code: 406,
+		},
+		{
+			req:  "GET /body",
+			reqh: "Accept:application/json Accept-Encoding:gzip",
+			code: 200,
+			resh: "Content-Encoding:-",
+		},
+		{
+			req:  "GET /body",
+			reqh: "Accept:application/xml",
+			code: 200,
+			resh: "Content-Encoding:-",
+		},
+		{
+			req:  "GET /body",
+			reqh: "Accept:application/xml Accept-Encoding:gzip",
+			code: 200,
+			resh: "Content-Encoding:gzip",
+		},
+		{
+			req:  "GET /sleep?duration=10ms",
+			code: 200,
+		},
+		{
+			req:  "GET /sleep?duration=110ms",
+			code: 503,
+		},
+		{
+			req:  "POST /random some-plain-text-body",
+			reqh: "Content-Type:text/plain",
+			code: 415,
+		},
+		{
+			req:  "POST /random {\"size\":4}",
+			reqh: "Content-Type:application/json",
+			code: 200,
+		},
+		{
+			req:  fmt.Sprintf("POST /random {\"size\":%d}", maxResponseBody+1),
+			reqh: "Content-Type:application/json",
+			code: 500,
+		},
+		{
+			req:  "POST /random {\"bloat\":\"abcdefghijklmnopqrstuvwxyz\",\"size\":1}",
+			reqh: "Content-Type:application/json",
+			code: 413,
 		},
 	}
 	for _, c := range cases {
 		t.Run(fmt.Sprintf("%s %s", c.req, c.reqh), func(t *testing.T) {
-			_ = h
+			logBuf.Reset()
+
+			req := newReq(c.req)
+			req.Header = headers(c.reqh)
+			res := httptest.NewRecorder()
+			h.ServeHTTP(res, req)
+
+			if res.Code != c.code {
+				t.Fatalf("want status code %d, got %d", c.code, res.Code)
+			}
+			if !strings.Contains(logBuf.String(), fmt.Sprintf("status=%d", c.code)) {
+				t.Fatalf("unexpected log message: %s", logBuf.String())
+			}
+
+			wanth := headers(c.resh)
+			for k := range wanth {
+				want := wanth.Get(k)
+				got := strings.Replace(res.Header().Get(k), " ", "", -1)
+
+				if want == "*" {
+					if got == "" {
+						t.Fatalf("want header %q to have a value", k)
+					}
+					continue
+				}
+				if want == "-" {
+					if got != "" {
+						t.Fatalf("want header %q to not be set, got %q", k, got)
+					}
+					continue
+				}
+				if want != got {
+					t.Fatalf("want header %q=%q, got %q", k, want, got)
+				}
+			}
+			t.Log(logBuf.String())
 		})
 	}
 }