~mna/kick

164e80fe576000d090daf286ba66c4d096f7a898 — Martin Angers 1 year, 7 months ago e42652b
test more middleware configuration cases
3 files changed, 315 insertions(+), 8 deletions(-)

M TODO.md
M builder/builder.go
M builder/builder_test.go
M TODO.md => TODO.md +3 -2
@@ 9,6 9,7 @@
7. 3rd party services/api calls with circuit breaker (also microservices)
8. Batch, async & one-off tools
9. Redirect http to https
10. Reverse proxy in front
10. Reverse proxy in front (e.g. distinct web + api)
11. Structured logging

12. Access path / router variables from handler
13. Disable http/2 option (maybe to support websocket)

M builder/builder.go => builder/builder.go +3 -3
@@ 163,6 163,9 @@ func rootMiddleware(s *kick.Server, router *httprouter.Router) ([]middleware, er

	var mw []middleware

	if fn := s.Root.LoggingFunc; fn != nil {
		mw = append(mw, logging(fn))
	}
	if fn := s.Root.PanicRecoveryFunc; fn != nil {
		mw = append(mw, panicRecovery(fn))
	}


@@ 179,9 182,6 @@ func rootMiddleware(s *kick.Server, router *httprouter.Router) ([]middleware, er
		sec := secure.New(*s.Root.SecurityHeaders)
		mw = append(mw, sec.Handler)
	}
	if fn := s.Root.LoggingFunc; fn != nil {
		mw = append(mw, logging(fn))
	}

	if s.Root.CanonicalHost != "" {
		redir := s.Root.CanonicalHostRedirectStatusCode

M builder/builder_test.go => builder/builder_test.go +309 -3
@@ 1,16 1,23 @@
package builder

import (
	"bytes"
	"crypto/rand"
	"crypto/tls"
	"encoding/json"
	"encoding/xml"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"
	"time"

	"git.sr.ht/~mna/kick"
	"github.com/sirupsen/logrus"
	"github.com/unrolled/secure"
)



@@ 269,6 276,24 @@ func TestHandler_InvalidRoute(t *testing.T) {
}

func TestHandler_Valid(t *testing.T) {
	var logBuf bytes.Buffer
	logrus.SetOutput(&logBuf)

	const (
		maxResponseBody = 10
		maxRequestBody  = 20
	)

	const xmlDoc = xml.Header + `<root>
  <value>hello</value>
</root>`

	const jsonDoc = `{
  "value": "hello"
}`

	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))

	s := &kick.Server{
		Root: &kick.Root{
			MethodNotAllowedHandler: statusHandler(405),


@@ 276,12 301,17 @@ func TestHandler_Valid(t *testing.T) {
			PanicRecoveryFunc: func(w http.ResponseWriter, r *http.Request, v interface{}) {
				w.WriteHeader(500)
			},
			LoggingFunc: func(w http.ResponseWriter, r *http.Request, info map[string]interface{}) {
				logrus.WithFields(info).Info()
			},
			TrustProxyHeaders:   true,
			AllowMethodOverride: true,
			RequestIDHeader:     "X-Request-Id",
			SecurityHeaders: &secure.Options{
				FrameDeny: true,
			},
			CanonicalHostRedirectStatusCode: 301,
			CanonicalHost:                   "http://example.com",
			Gzip: &kick.GzipConfig{
				ContentTypes: []string{"application/xml"},
			},


@@ 293,12 323,106 @@ func TestHandler_Valid(t *testing.T) {
				Handler: statusHandler(200),
			},
			{
				Method:  "GET",
				Path:    "/noslash",
				Handler: statusHandler(200),
			},
			{
				Method:  "PUT",
				Path:    "/noslash",
				Handler: statusHandler(204),
			},
			{
				Method:  "GET",
				Path:    "/slash/",
				Handler: statusHandler(200),
			},
			{
				Method: "GET",
				Path:   "/sleep",
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					dur, err := time.ParseDuration(r.URL.Query().Get("duration"))
					if err != nil {
						panic(err)
					}
					time.Sleep(dur)
					w.WriteHeader(200)
				}),
				Config: &kick.HandlerConfig{
					HandlerTimeout: 100 * time.Millisecond,
				},
			},
			{
				Method: "GET",
				Path:   "/body",
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					ct := w.Header().Get("Content-Type")
					switch ct {
					case "application/xml":
						fmt.Fprint(w, xmlDoc)
					case "application/json":
						fmt.Fprint(w, jsonDoc)
					default:
						w.WriteHeader(400)
					}
				}),
				Config: &kick.HandlerConfig{
					ResponseContentTypes: []string{"application/json", "application/xml"},
				},
			},
			{
				Method: "GET",
				Path:   "/validate",
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					q := r.URL.Query()
					if scheme := q.Get("scheme"); scheme != "" {
						if scheme != r.URL.Scheme {
							w.WriteHeader(400)
							return
						}
					}
					if ra := q.Get("remote"); ra != "" {
						if ra != r.RemoteAddr {
							w.WriteHeader(400)
							return
						}
					}
					w.WriteHeader(200)
				}),
			},
			{
				Method: "POST",
				Path:   "/panic",
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					panic(io.EOF)
				}),
			},
			{
				Method: "POST",
				Path:   "/random",
				Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					b, err := ioutil.ReadAll(r.Body)
					if err != nil {
						if strings.Contains(err.Error(), "too large") {
							w.WriteHeader(413)
						} else {
							w.WriteHeader(500)
						}
						return
					}
					var m map[string]int64
					if err := json.Unmarshal(b, &m); err != nil {
						w.WriteHeader(400)
						return
					}
					io.CopyN(w, rand.Reader, m["size"])
				}),
				Config: &kick.HandlerConfig{
					RequestContentTypes:  []string{"application/json"},
					MaxResponseBodyBytes: maxResponseBody,
					MaxRequestBodyBytes:  maxRequestBody,
				},
			},
		},
	}



@@ 307,21 431,203 @@ func TestHandler_Valid(t *testing.T) {
		t.Fatalf("want no error, got %s", err)
	}

	newReq := func(s string) *http.Request {
		const partCount = 3

		parts := strings.SplitN(s, " ", partCount)
		for len(parts) < partCount {
			parts = append(parts, "")
		}

		req := httptest.NewRequest(parts[0], parts[1], strings.NewReader(parts[2]))
		return req
	}

	headers := func(s string) http.Header {
		if s == "" {
			return http.Header{}
		}

		kvs := strings.Split(s, " ")

		h := make(http.Header, len(kvs))
		for _, kv := range kvs {
			ix := strings.Index(kv, ":")
			if ix < 0 {
				t.Fatalf("header key-value must contain ':': %s", kv)
			}
			h.Add(kv[:ix], kv[ix+1:])
		}
		return h
	}

	cases := []struct {
		req  string // formatted as "METHOD /path/... body"
		reqh string // formatted as "Name:Value Name:Value"
		code int
		resh string // formatted as "Name:Value", * means any value
		resh string // formatted as "Name:Value", * means any value, - means not set
	}{
		{
			req:  "GET /",
			code: 200,
			resh: "X-Request-Id:*",
			resh: "X-Request-Id:* X-Frame-Options:DENY",
		},
		{
			req:  "GET /",
			reqh: "X-Request-Id:abcd",
			code: 200,
			resh: "X-Request-Id:abcd",
		},
		{
			req:  "GET /noslash/",
			code: 301,
			resh: "Location:/noslash",
		},
		{
			req:  "GET /slash",
			code: 301,
			resh: "Location:/slash/",
		},
		{
			req:  "GET /noslash/.//",
			code: 301,
			resh: "Location:/noslash",
		},
		{
			req:  "GET /panic",
			code: 405,
			resh: "Allow:POST,OPTIONS",
		},
		{
			req:  "GET /notexist",
			code: 404,
		},
		{
			req:  "POST /panic",
			code: 500,
		},
		{
			req:  "GET http://example.com/validate?scheme=https",
			code: 400,
		},
		{
			req:  "GET http://example.com/validate?scheme=https",
			reqh: "X-Forwarded-Proto:https",
			code: 200,
		},
		{
			req:  "GET /validate?remote=1.2.3.4",
			code: 400,
		},
		{
			req:  "GET /validate?remote=1.2.3.4",
			reqh: "X-Forwarded-For:1.2.3.4",
			code: 200,
		},
		{
			req:  "POST /noslash",
			code: 405,
		},
		{
			req:  "POST /noslash",
			reqh: "X-HTTP-Method-Override:PUT",
			code: 204,
		},
		{
			req:  "GET http://www.example.com/",
			code: 301,
			resh: "Location:http://example.com/",
		},
		{
			req:  "GET /body",
			reqh: "Accept:text/plain",
			code: 406,
		},
		{
			req:  "GET /body",
			reqh: "Accept:application/json Accept-Encoding:gzip",
			code: 200,
			resh: "Content-Encoding:-",
		},
		{
			req:  "GET /body",
			reqh: "Accept:application/xml",
			code: 200,
			resh: "Content-Encoding:-",
		},
		{
			req:  "GET /body",
			reqh: "Accept:application/xml Accept-Encoding:gzip",
			code: 200,
			resh: "Content-Encoding:gzip",
		},
		{
			req:  "GET /sleep?duration=10ms",
			code: 200,
		},
		{
			req:  "GET /sleep?duration=110ms",
			code: 503,
		},
		{
			req:  "POST /random some-plain-text-body",
			reqh: "Content-Type:text/plain",
			code: 415,
		},
		{
			req:  "POST /random {\"size\":4}",
			reqh: "Content-Type:application/json",
			code: 200,
		},
		{
			req:  fmt.Sprintf("POST /random {\"size\":%d}", maxResponseBody+1),
			reqh: "Content-Type:application/json",
			code: 500,
		},
		{
			req:  "POST /random {\"bloat\":\"abcdefghijklmnopqrstuvwxyz\",\"size\":1}",
			reqh: "Content-Type:application/json",
			code: 413,
		},
	}
	for _, c := range cases {
		t.Run(fmt.Sprintf("%s %s", c.req, c.reqh), func(t *testing.T) {
			_ = h
			logBuf.Reset()

			req := newReq(c.req)
			req.Header = headers(c.reqh)
			res := httptest.NewRecorder()
			h.ServeHTTP(res, req)

			if res.Code != c.code {
				t.Fatalf("want status code %d, got %d", c.code, res.Code)
			}
			if !strings.Contains(logBuf.String(), fmt.Sprintf("status=%d", c.code)) {
				t.Fatalf("unexpected log message: %s", logBuf.String())
			}

			wanth := headers(c.resh)
			for k := range wanth {
				want := wanth.Get(k)
				got := strings.Replace(res.Header().Get(k), " ", "", -1)

				if want == "*" {
					if got == "" {
						t.Fatalf("want header %q to have a value", k)
					}
					continue
				}
				if want == "-" {
					if got != "" {
						t.Fatalf("want header %q to not be set, got %q", k, got)
					}
					continue
				}
				if want != got {
					t.Fatalf("want header %q=%q, got %q", k, want, got)
				}
			}
			t.Log(logBuf.String())
		})
	}
}