~samwhited/mux

dc885c17a3b47d967b15f813d3cb6f8be43b90b5 — Sam Whited 1 year, 7 months ago b7441e8
mux: support OPTIONS requests by default
6 files changed, 208 insertions(+), 88 deletions(-)

D handler_test.go
R notfound.go => handlers.go
A handlers_test.go
M mux.go
M node.go
M params.go
D handler_test.go => handler_test.go +0 -57
@@ 1,57 0,0 @@
package mux_test

import (
	"net/http"
	"net/http/httptest"
	"strconv"
	"testing"

	"code.soquee.net/mux"
)

var notFoundTests = [...]struct {
	h    http.HandlerFunc
	code int
}{
	0: {
		h: func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusPaymentRequired)
		},
		code: http.StatusPaymentRequired,
	},
	1: {
		h: func(w http.ResponseWriter, r *http.Request) {
			w.WriteHeader(http.StatusMultiStatus)
			_, err := w.Write([]byte("Test"))
			if err != nil {
				panic(err)
			}
		},
		code: http.StatusMultiStatus,
	},
	2: {
		h: func(w http.ResponseWriter, r *http.Request) {
			_, err := w.Write([]byte("Test"))
			if err != nil {
				panic(err)
			}
		},
		code: http.StatusNotFound,
	},
}

// NotFound handlers should always return a 404. If the underlying handler
// doesn't explicitly call WriteHeader, we should set a 404 as the default
// instead of a 200.
func TestNotFoundAlwaysSetsStatusCode(t *testing.T) {
	for i, tc := range notFoundTests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			m := mux.New(mux.NotFound(tc.h))
			rec := httptest.NewRecorder()
			m.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
			if rec.Code != tc.code {
				t.Errorf("Wrong status code from handler: want=%d, got=%d", tc.code, rec.Code)
			}
		})
	}
}

R notfound.go => handlers.go +12 -0
@@ 2,6 2,7 @@ package mux

import (
	"net/http"
	"strings"
)

// defCodeWriter is an http.ResponseWriter that writes the given status code by


@@ 33,3 34,14 @@ func notFoundHandler(h http.Handler) http.HandlerFunc {
		}, r)
	}
}

func defOptions(node node) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		var verbs []string
		for v, _ := range node.handlers {
			verbs = append(verbs, v)
		}
		w.Header().Add("Allow", strings.Join(verbs, ","))
		w.Write(nil)
	})
}

A handlers_test.go => handlers_test.go +131 -0
@@ 0,0 1,131 @@
package mux_test

import (
	"net/http"
	"net/http/httptest"
	"sort"
	"strconv"
	"strings"
	"testing"

	"code.soquee.net/mux"
)

const (
	testBody = "Test"
)

var handlerTests = [...]struct {
	opts     func(t *testing.T) []mux.Option
	method   string
	req      string
	code     int
	respBody string
	header   http.Header
}{
	0: {
		opts: func(t *testing.T) []mux.Option {
			return []mux.Option{
				mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusPaymentRequired)
				})),
			}
		},
		code: http.StatusPaymentRequired,
	},
	1: {
		opts: func(t *testing.T) []mux.Option {
			return []mux.Option{
				mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					w.WriteHeader(http.StatusMultiStatus)
					_, err := w.Write([]byte(testBody))
					if err != nil {
						panic(err)
					}
				})),
			}
		},
		code:     http.StatusMultiStatus,
		respBody: testBody,
	},
	2: {
		opts: func(t *testing.T) []mux.Option {
			return []mux.Option{
				mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					_, err := w.Write([]byte(testBody))
					if err != nil {
						panic(err)
					}
				})),
			}
		},
		code:     http.StatusNotFound,
		respBody: testBody,
	},
	3: {
		opts: func(t *testing.T) []mux.Option {
			return []mux.Option{
				mux.Handle("GET", "/", failHandler(t)),
				mux.Handle("POST", "/", failHandler(t)),
				mux.Handle("PUT", "/test", failHandler(t)),
			}
		},
		method: http.MethodOptions,
		code:   http.StatusOK,
		header: map[string][]string{
			"Allow": []string{"GET,POST"},
		},
	},
	4: {
		method: http.MethodOptions,
		code:   http.StatusOK,
		header: map[string][]string{
			"Allow": []string{""},
		},
	},
}

func TestHandlers(t *testing.T) {
	for i, tc := range handlerTests {
		t.Run(strconv.Itoa(i), func(t *testing.T) {
			if tc.opts == nil {
				tc.opts = func(*testing.T) []mux.Option { return []mux.Option{} }
			}
			m := mux.New(tc.opts(t)...)
			rec := httptest.NewRecorder()
			if tc.req == "" {
				tc.req = "/"
			}
			if tc.method == "" {
				tc.method = http.MethodGet
			}
			m.ServeHTTP(rec, httptest.NewRequest(tc.method, tc.req, nil))
			if rec.Code != tc.code {
				t.Errorf("Unexpected status code: want=%d, got=%d", tc.code, rec.Code)
			}
			if s := rec.Body.String(); s != tc.respBody {
				t.Errorf("Unexpected response body: want=%q, got=%q", tc.respBody, s)
			}
			for k, _ := range tc.header {
				var v, vv string
				if k == "Allow" {
					// Sort "Allow" headers as a special case so that we don't have to do
					// a sort or anything in the actual handler.
					methods := strings.Split(tc.header.Get(k), ",")
					sort.Strings(methods)
					v = strings.Join(methods, ",")

					methods = strings.Split(rec.HeaderMap.Get(k), ",")
					sort.Strings(methods)
					vv = strings.Join(methods, ",")
				} else {
					v = tc.header.Get(k)
					vv = rec.HeaderMap.Get(k)
				}
				if vv != v {
					t.Errorf("Unexpected value for header %q: want=%q, got=%q", k, v, vv)
				}
			}
		})
	}
}

M mux.go => mux.go +55 -24
@@ 58,6 58,24 @@ const (
type ServeMux struct {
	node
	notFound http.Handler
	options  func(node) http.Handler
}

// New allocates and returns a new ServeMux.
func New(opts ...Option) *ServeMux {
	mux := &ServeMux{
		node: node{
			name:     "/",
			typ:      typStatic,
			handlers: make(map[string]http.Handler),
		},
		notFound: http.HandlerFunc(http.NotFound),
		options:  defOptions,
	}
	for _, o := range opts {
		o(mux)
	}
	return mux
}

// ServeHTTP dispatches the request to the handler whose pattern most closely


@@ 116,15 134,14 @@ func (mux *ServeMux) Handler(r *http.Request) (http.Handler, *http.Request) {
func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {
	// TODO: Add /tree to /tree/ redirect option and apply here.
	// TODO: use host
	host := r.Host
	_ = host
	//host := r.URL.Host
	path := r.URL.Path

	// CONNECT requests are not canonicalized
	if r.Method != http.MethodConnect {
		// All other requests have any port stripped and path cleaned
		// before passing to mux.handler.
		host = stripHostPort(r.Host)
		//host = stripHostPort(r.Host)
		path = cleanPath(r.URL.Path)
		if path != r.URL.Path {
			url := *r.URL


@@ 136,10 153,15 @@ func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {
	// TODO: add host based matching and check it here.
	node := &mux.node
	path = strings.TrimPrefix(path, "/")

	// Requests for /
	if path == "" {
		h, ok := node.handlers[r.Method]
		h, ok := mux.node.handlers[r.Method]
		if !ok {
			// TODO: method not supported vs not found config
			if r.Method == http.MethodOptions && mux.options != nil {
				return mux.options(mux.node), r
			}
			return mux.notFound, r
		}
		return h, r


@@ 149,7 171,7 @@ func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {

nodeloop:
	for node != nil {
		// If this is a variable route,
		// If this is a variable route
		if len(node.child) == 1 && node.child[0].typ != typStatic {
			var part, remain string
			part, remain, r = node.child[0].match(path, offset, r)


@@ 166,6 188,9 @@ nodeloop:
				h, ok := node.child[0].handlers[r.Method]
				if !ok {
					// TODO: method not supported vs not found config
					if r.Method == http.MethodOptions && mux.options != nil {
						return mux.options(node.child[0]), r
					}
					return mux.notFound, r
				}
				return h, r


@@ 192,6 217,9 @@ nodeloop:
				h, ok := child.handlers[r.Method]
				if !ok {
					// TODO: method not supported vs not found config
					if r.Method == http.MethodOptions && mux.options != nil {
						return mux.options(child), r
					}
					return mux.notFound, r
				}
				return h, r


@@ 210,25 238,6 @@ nodeloop:
	return mux.notFound, r
}

// ctxParam is a type used for context keys that contain route parameters.
type ctxParam string

// New allocates and returns a new ServeMux.
func New(opts ...Option) *ServeMux {
	mux := &ServeMux{
		node: node{
			name:     "/",
			typ:      typStatic,
			handlers: make(map[string]http.Handler),
		},
		notFound: http.HandlerFunc(http.NotFound),
	}
	for _, o := range opts {
		o(mux)
	}
	return mux
}

// Option is used to configure a ServeMux.
type Option func(*ServeMux)



@@ 245,6 254,28 @@ func NotFound(h http.Handler) Option {
	}
}

// The ServeMux handles OPTIONS requests by default. If you do not want this
// behavior, set f to "nil".
//
// Registering handlers for OPTIONS requests on a specific path always overrides
// the default handler.
func DefaultOptions(f func([]string) http.Handler) Option {
	return func(mux *ServeMux) {
		if f == nil {
			mux.options = nil
			return
		}

		mux.options = func(n node) http.Handler {
			var verbs []string
			for v, _ := range n.handlers {
				verbs = append(verbs, v)
			}
			return f(verbs)
		}
	}
}

// Handle registers the handler for the given pattern.
// If a handler already exists for pattern, Handle panics.
func HandleFunc(method, r string, h http.HandlerFunc) Option {

M node.go => node.go +7 -7
@@ 63,14 63,14 @@ func (n *node) match(path string, offset uint, r *http.Request) (part string, re
}

func addValue(r *http.Request, name, typ, raw string, offset uint, val interface{}) *http.Request {
	pinfo := ParamInfo{
		Value:  val,
		Raw:    raw,
		Name:   name,
		Type:   typ,
		Offset: offset,
	}
	if name != "" {
		pinfo := ParamInfo{
			Value:  val,
			Raw:    raw,
			Name:   name,
			Type:   typ,
			Offset: offset,
		}
		return r.WithContext(context.WithValue(r.Context(), ctxParam(name), pinfo))
	}
	return r

M params.go => params.go +3 -0
@@ 4,6 4,9 @@ import (
	"net/http"
)

// ctxParam is a type used for context keys that contain route parameters.
type ctxParam string

// ParamInfo represents a route parameter and related metadata.
type ParamInfo struct {
	// The parsed value of the parameter (for example int64(10))