41e6ab767d21c9792c129d6f034cfae3b59860a2 — Martin Angers 7 months ago ffc2e85
builder: support OPTIONS and HEAD requests, more tests
3 files changed, 91 insertions(+), 12 deletions(-)

M builder/builder.go
M builder/builder_test.go
M builder/middleware.go
M builder/builder.go => builder/builder.go +39 -3
@@ 128,9 128,12 @@
 // Handler builds the handler for the routes defined on the
 // server, including any required middleware.
 func Handler(s *kick.Server) (http.Handler, error) {
-	router := httprouter.New()
-	router.RedirectTrailingSlash = true
-	router.RedirectFixedPath = true
+	router := &httprouter.Router{
+		RedirectTrailingSlash:  true,
+		RedirectFixedPath:      true,
+		HandleMethodNotAllowed: false,
+		HandleOPTIONS:          true,
+	}
 
 	var mw []middleware
 	var err error


@@ 148,6 151,11 @@
 		}
 	}
 
+	// once all routes have been registered, add any missing HEAD
+	// and OPTIONS handlers (for existing GET and CORS-configured
+	// routes, respectively).
+	configureGetAndOptions(router, s.Routes)
+
 	h := combineMiddleware(router, mw)
 	return h, nil
 }


@@ 298,6 306,34 @@
 	return nil
 }
 
+func configureGetAndOptions(router *httprouter.Router, routes []*kick.Route) {
+	for _, route := range routes {
+		switch {
+		case route.Method == "GET":
+			// configure a HEAD handler if there is not already one
+			h, _, _ := router.Lookup("HEAD", route.Path)
+			if h != nil {
+				continue
+			}
+
+			// get the existing handler and register it under HEAD
+			geth, _, _ := router.Lookup(route.Method, route.Path)
+			router.HEAD(route.Path, geth)
+
+		case route.Config != nil && route.Config.CORS != nil:
+			// configure an OPTIONS handler if there is not already one
+			h, _, _ := router.Lookup("OPTIONS", route.Path)
+			if h != nil {
+				continue
+			}
+
+			// get the existing handler and register it under OPTIONS
+			curh, _, _ := router.Lookup(route.Method, route.Path)
+			router.OPTIONS(route.Path, curh)
+		}
+	}
+}
+
 func combineMiddleware(h http.Handler, mw []middleware) http.Handler {
 	for i := len(mw) - 1; i >= 0; i-- {
 		h = mw[i](h)

M builder/builder_test.go => builder/builder_test.go +45 -8
@@ 12,6 12,7 @@
 	"net/http"
 	"net/http/httptest"
 	"os"
+	"strconv"
 	"strings"
 	"testing"
 	"time"


@@ 294,6 295,11 @@
 
 	t.Logf("xml uncompressed body size: %d; json uncompressed body size: %d", len(xmlDoc), len(jsonDoc))
 
+	noslashCORS := &kick.CORSConfig{
+		AllowedMethods: []string{"GET", "PUT"},
+		AllowedOrigins: []string{"http://notexample.com", "notexample.com"},
+	}
+
 	s := &kick.Server{
 		Root: &kick.Root{
 			MethodNotAllowedHandler: statusHandler(405),


@@ 326,16 332,16 @@
 				Method:  "GET",
 				Path:    "/noslash",
 				Handler: statusHandler(200),
+				Config: &kick.HandlerConfig{
+					CORS: noslashCORS,
+				},
 			},
 			{
 				Method:  "PUT",
 				Path:    "/noslash",
 				Handler: statusHandler(204),
 				Config: &kick.HandlerConfig{
-					CORS: &kick.CORSConfig{
-						AllowedMethods: []string{"PUT"},
-						AllowedOrigins: []string{"example.com"},
-					},
+					CORS: noslashCORS,
 				},
 			},
 			{


@@ 365,8 371,10 @@
 					ct := w.Header().Get("Content-Type")
 					switch ct {
 					case "application/xml":
+						w.Header().Add("Content-Length", strconv.Itoa(len(xmlDoc)))
 						fmt.Fprint(w, xmlDoc)
 					case "application/json":
+						w.Header().Add("Content-Length", strconv.Itoa(len(jsonDoc)))
 						fmt.Fprint(w, jsonDoc)
 					default:
 						w.WriteHeader(400)


@@ 485,6 493,11 @@
 		resh string // formatted as "Name:Value", * means any value, - means not set
 	}{
 		{
+			req:  "OPTIONS *",
+			code: 200,
+			resh: "Allow:*",
+		},
+		{
 			req:  "GET /",
 			code: 200,
 			resh: "X-Request-Id:* X-Frame-Options:DENY",


@@ 573,6 586,12 @@
 			resh: "Content-Encoding:-",
 		},
 		{
+			req:  "HEAD /body",
+			reqh: "Accept:application/xml",
+			code: 200,
+			resh: fmt.Sprintf("Content-Length:%d", len(xmlDoc)),
+		},
+		{
 			req:  "GET /body",
 			reqh: "Accept:application/xml Accept-Encoding:gzip",
 			code: 200,


@@ 618,9 637,20 @@
 		},
 		{
 			req:  "OPTIONS /noslash",
-			reqh: "Origin:http://notexample.com Access-Control-Request-Method:PUT",
-			code: 204,
-			resh: "Access-Control-Allow-Origin:x",
+			reqh: "Origin:notexample.com Access-Control-Request-Method:PUT",
+			code: 200,
+			resh: "Access-Control-Allow-Origin:notexample.com Access-Control-Allow-Methods:PUT",
+		},
+		{
+			req:  "OPTIONS /noslash",
+			reqh: "Origin:notexample.com Access-Control-Request-Method:PATCH",
+			code: 405,
+		},
+		{
+			req:  "GET /noslash",
+			reqh: "Origin:notexample.com",
+			code: 200,
+			resh: "Access-Control-Allow-Origin:notexample.com",
 		},
 	}
 	for _, c := range cases {


@@ 657,9 687,16 @@
 					continue
 				}
 				if want != got {
-					t.Fatalf("want header %q=%q, got %q", k, want, got)
+					t.Fatalf("want header %q=%q, got %q (headers: %v)", k, want, got, res.Header())
 				}
 			}
+
+			// TODO: HEAD body should be verified by an actual client call
+			// to that endpoint - I think the ResponseWriter sees the writes,
+			// but they wouldn't make it to the wire and the client wouldn't
+			// see them.
+			// TODO: same for CSRF, should be tested via a client with a
+			// CookieJar.
 			t.Log(logBuf.String())
 		})
 	}

M builder/middleware.go => builder/middleware.go +7 -1
@@ 206,7 206,13 @@
 				Size() int64
 			}); ok {
 				m["body_bytes_sent"] = w.Size()
-				m["status"] = w.Status()
+				status := w.Status()
+				if status == 0 {
+					// can happen if returning without writing a body nor
+					// an explicit status.
+					status = 200
+				}
+				m["status"] = status
 			}
 			logFn(w, r, m)
 		})