D handler_test.go => handler_test.go +0 -57
@@ 1,57 0,0 @@
-package mux_test
-
-import (
- "net/http"
- "net/http/httptest"
- "strconv"
- "testing"
-
- "code.soquee.net/mux"
-)
-
-var notFoundTests = [...]struct {
- h http.HandlerFunc
- code int
-}{
- 0: {
- h: func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusPaymentRequired)
- },
- code: http.StatusPaymentRequired,
- },
- 1: {
- h: func(w http.ResponseWriter, r *http.Request) {
- w.WriteHeader(http.StatusMultiStatus)
- _, err := w.Write([]byte("Test"))
- if err != nil {
- panic(err)
- }
- },
- code: http.StatusMultiStatus,
- },
- 2: {
- h: func(w http.ResponseWriter, r *http.Request) {
- _, err := w.Write([]byte("Test"))
- if err != nil {
- panic(err)
- }
- },
- code: http.StatusNotFound,
- },
-}
-
-// NotFound handlers should always return a 404. If the underlying handler
-// doesn't explicitly call WriteHeader, we should set a 404 as the default
-// instead of a 200.
-func TestNotFoundAlwaysSetsStatusCode(t *testing.T) {
- for i, tc := range notFoundTests {
- t.Run(strconv.Itoa(i), func(t *testing.T) {
- m := mux.New(mux.NotFound(tc.h))
- rec := httptest.NewRecorder()
- m.ServeHTTP(rec, httptest.NewRequest("GET", "/", nil))
- if rec.Code != tc.code {
- t.Errorf("Wrong status code from handler: want=%d, got=%d", tc.code, rec.Code)
- }
- })
- }
-}
R notfound.go => handlers.go +12 -0
@@ 2,6 2,7 @@ package mux
import (
"net/http"
+ "strings"
)
// defCodeWriter is an http.ResponseWriter that writes the given status code by
@@ 33,3 34,14 @@ func notFoundHandler(h http.Handler) http.HandlerFunc {
}, r)
}
}
+
+func defOptions(node node) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ var verbs []string
+ for v, _ := range node.handlers {
+ verbs = append(verbs, v)
+ }
+ w.Header().Add("Allow", strings.Join(verbs, ","))
+ w.Write(nil)
+ })
+}
A handlers_test.go => handlers_test.go +131 -0
@@ 0,0 1,131 @@
+package mux_test
+
+import (
+ "net/http"
+ "net/http/httptest"
+ "sort"
+ "strconv"
+ "strings"
+ "testing"
+
+ "code.soquee.net/mux"
+)
+
+const (
+ testBody = "Test"
+)
+
+var handlerTests = [...]struct {
+ opts func(t *testing.T) []mux.Option
+ method string
+ req string
+ code int
+ respBody string
+ header http.Header
+}{
+ 0: {
+ opts: func(t *testing.T) []mux.Option {
+ return []mux.Option{
+ mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusPaymentRequired)
+ })),
+ }
+ },
+ code: http.StatusPaymentRequired,
+ },
+ 1: {
+ opts: func(t *testing.T) []mux.Option {
+ return []mux.Option{
+ mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusMultiStatus)
+ _, err := w.Write([]byte(testBody))
+ if err != nil {
+ panic(err)
+ }
+ })),
+ }
+ },
+ code: http.StatusMultiStatus,
+ respBody: testBody,
+ },
+ 2: {
+ opts: func(t *testing.T) []mux.Option {
+ return []mux.Option{
+ mux.NotFound(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ _, err := w.Write([]byte(testBody))
+ if err != nil {
+ panic(err)
+ }
+ })),
+ }
+ },
+ code: http.StatusNotFound,
+ respBody: testBody,
+ },
+ 3: {
+ opts: func(t *testing.T) []mux.Option {
+ return []mux.Option{
+ mux.Handle("GET", "/", failHandler(t)),
+ mux.Handle("POST", "/", failHandler(t)),
+ mux.Handle("PUT", "/test", failHandler(t)),
+ }
+ },
+ method: http.MethodOptions,
+ code: http.StatusOK,
+ header: map[string][]string{
+ "Allow": []string{"GET,POST"},
+ },
+ },
+ 4: {
+ method: http.MethodOptions,
+ code: http.StatusOK,
+ header: map[string][]string{
+ "Allow": []string{""},
+ },
+ },
+}
+
+func TestHandlers(t *testing.T) {
+ for i, tc := range handlerTests {
+ t.Run(strconv.Itoa(i), func(t *testing.T) {
+ if tc.opts == nil {
+ tc.opts = func(*testing.T) []mux.Option { return []mux.Option{} }
+ }
+ m := mux.New(tc.opts(t)...)
+ rec := httptest.NewRecorder()
+ if tc.req == "" {
+ tc.req = "/"
+ }
+ if tc.method == "" {
+ tc.method = http.MethodGet
+ }
+ m.ServeHTTP(rec, httptest.NewRequest(tc.method, tc.req, nil))
+ if rec.Code != tc.code {
+ t.Errorf("Unexpected status code: want=%d, got=%d", tc.code, rec.Code)
+ }
+ if s := rec.Body.String(); s != tc.respBody {
+ t.Errorf("Unexpected response body: want=%q, got=%q", tc.respBody, s)
+ }
+ for k, _ := range tc.header {
+ var v, vv string
+ if k == "Allow" {
+ // Sort "Allow" headers as a special case so that we don't have to do
+ // a sort or anything in the actual handler.
+ methods := strings.Split(tc.header.Get(k), ",")
+ sort.Strings(methods)
+ v = strings.Join(methods, ",")
+
+ methods = strings.Split(rec.HeaderMap.Get(k), ",")
+ sort.Strings(methods)
+ vv = strings.Join(methods, ",")
+ } else {
+ v = tc.header.Get(k)
+ vv = rec.HeaderMap.Get(k)
+ }
+ if vv != v {
+ t.Errorf("Unexpected value for header %q: want=%q, got=%q", k, v, vv)
+ }
+ }
+ })
+ }
+}
M mux.go => mux.go +55 -24
@@ 58,6 58,24 @@ const (
type ServeMux struct {
node
notFound http.Handler
+ options func(node) http.Handler
+}
+
+// New allocates and returns a new ServeMux.
+func New(opts ...Option) *ServeMux {
+ mux := &ServeMux{
+ node: node{
+ name: "/",
+ typ: typStatic,
+ handlers: make(map[string]http.Handler),
+ },
+ notFound: http.HandlerFunc(http.NotFound),
+ options: defOptions,
+ }
+ for _, o := range opts {
+ o(mux)
+ }
+ return mux
}
// ServeHTTP dispatches the request to the handler whose pattern most closely
@@ 116,15 134,14 @@ func (mux *ServeMux) Handler(r *http.Request) (http.Handler, *http.Request) {
func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {
// TODO: Add /tree to /tree/ redirect option and apply here.
// TODO: use host
- host := r.Host
- _ = host
+ //host := r.URL.Host
path := r.URL.Path
// CONNECT requests are not canonicalized
if r.Method != http.MethodConnect {
// All other requests have any port stripped and path cleaned
// before passing to mux.handler.
- host = stripHostPort(r.Host)
+ //host = stripHostPort(r.Host)
path = cleanPath(r.URL.Path)
if path != r.URL.Path {
url := *r.URL
@@ 136,10 153,15 @@ func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {
// TODO: add host based matching and check it here.
node := &mux.node
path = strings.TrimPrefix(path, "/")
+
+ // Requests for /
if path == "" {
- h, ok := node.handlers[r.Method]
+ h, ok := mux.node.handlers[r.Method]
if !ok {
// TODO: method not supported vs not found config
+ if r.Method == http.MethodOptions && mux.options != nil {
+ return mux.options(mux.node), r
+ }
return mux.notFound, r
}
return h, r
@@ 149,7 171,7 @@ func (mux *ServeMux) handler(r *http.Request) (http.Handler, *http.Request) {
nodeloop:
for node != nil {
- // If this is a variable route,
+ // If this is a variable route
if len(node.child) == 1 && node.child[0].typ != typStatic {
var part, remain string
part, remain, r = node.child[0].match(path, offset, r)
@@ 166,6 188,9 @@ nodeloop:
h, ok := node.child[0].handlers[r.Method]
if !ok {
// TODO: method not supported vs not found config
+ if r.Method == http.MethodOptions && mux.options != nil {
+ return mux.options(node.child[0]), r
+ }
return mux.notFound, r
}
return h, r
@@ 192,6 217,9 @@ nodeloop:
h, ok := child.handlers[r.Method]
if !ok {
// TODO: method not supported vs not found config
+ if r.Method == http.MethodOptions && mux.options != nil {
+ return mux.options(child), r
+ }
return mux.notFound, r
}
return h, r
@@ 210,25 238,6 @@ nodeloop:
return mux.notFound, r
}
-// ctxParam is a type used for context keys that contain route parameters.
-type ctxParam string
-
-// New allocates and returns a new ServeMux.
-func New(opts ...Option) *ServeMux {
- mux := &ServeMux{
- node: node{
- name: "/",
- typ: typStatic,
- handlers: make(map[string]http.Handler),
- },
- notFound: http.HandlerFunc(http.NotFound),
- }
- for _, o := range opts {
- o(mux)
- }
- return mux
-}
-
// Option is used to configure a ServeMux.
type Option func(*ServeMux)
@@ 245,6 254,28 @@ func NotFound(h http.Handler) Option {
}
}
+// The ServeMux handles OPTIONS requests by default. If you do not want this
+// behavior, set f to "nil".
+//
+// Registering handlers for OPTIONS requests on a specific path always overrides
+// the default handler.
+func DefaultOptions(f func([]string) http.Handler) Option {
+ return func(mux *ServeMux) {
+ if f == nil {
+ mux.options = nil
+ return
+ }
+
+ mux.options = func(n node) http.Handler {
+ var verbs []string
+ for v, _ := range n.handlers {
+ verbs = append(verbs, v)
+ }
+ return f(verbs)
+ }
+ }
+}
+
// Handle registers the handler for the given pattern.
// If a handler already exists for pattern, Handle panics.
func HandleFunc(method, r string, h http.HandlerFunc) Option {
M node.go => node.go +7 -7
@@ 63,14 63,14 @@ func (n *node) match(path string, offset uint, r *http.Request) (part string, re
}
func addValue(r *http.Request, name, typ, raw string, offset uint, val interface{}) *http.Request {
- pinfo := ParamInfo{
- Value: val,
- Raw: raw,
- Name: name,
- Type: typ,
- Offset: offset,
- }
if name != "" {
+ pinfo := ParamInfo{
+ Value: val,
+ Raw: raw,
+ Name: name,
+ Type: typ,
+ Offset: offset,
+ }
return r.WithContext(context.WithValue(r.Context(), ctxParam(name), pinfo))
}
return r
M params.go => params.go +3 -0
@@ 4,6 4,9 @@ import (
"net/http"
)
+// ctxParam is a type used for context keys that contain route parameters.
+type ctxParam string
+
// ParamInfo represents a route parameter and related metadata.
type ParamInfo struct {
// The parsed value of the parameter (for example int64(10))