// Package builder defines the default builder functions that
// generate a kick server.
package builder
import (
"compress/gzip"
"crypto/tls"
"errors"
"fmt"
"net/http"
"time"
"git.sr.ht/~mna/kick"
"github.com/NYTimes/gziphandler"
"github.com/gorilla/csrf"
"github.com/gorilla/handlers"
"github.com/julienschmidt/httprouter"
"github.com/unrolled/secure"
"golang.org/x/crypto/acme/autocert"
)
// Default is a ready-to-use *kick.Builders wired to the TLS, HTTPServer
// and Handler builder functions defined in this package.
var Default = &kick.Builders{
	Handler:    Handler,
	HTTPServer: HTTPServer,
	TLS:        TLS,
}
// TLS builds the TLS configuration for the server.
//
// It returns (nil, nil) when s.TLS is nil, i.e. when no TLS settings are
// configured. Otherwise the base configuration comes from
// configureTLSAutoCert when s.TLS.AutoCert is set, or is an empty
// tls.Config whose certificate is loaded from s.TLS.CertFile and
// s.TLS.KeyFile.
//
// s.TLS.Mode then adjusts the base configuration:
//   - kick.TLSDefault leaves it as is;
//   - kick.TLSIntermediate sets server cipher suite preference and
//     restricts the curve preferences;
//   - kick.TLSModern does the same, and additionally requires TLS 1.2 or
//     later and restricts the cipher suites to ECDHE AEAD suites.
//
// Any other mode is reported as an error.
func TLS(s *kick.Server) (*tls.Config, error) {
	conf := s.TLS
	if conf == nil {
		return nil, nil
	}

	var tc *tls.Config
	if conf.AutoCert {
		tc = configureTLSAutoCert(conf)
	} else {
		tc = &tls.Config{}
	}

	switch conf.Mode {
	case kick.TLSDefault:
		// Nothing to adjust: the base configuration is used untouched.

	case kick.TLSIntermediate, kick.TLSModern:
		// Let the server pick the cipher suite from its own preference
		// order. NOTE(review): crypto/tls ignores this field since Go 1.18.
		tc.PreferServerCipherSuites = true
		// Restrict to curves that have assembly implementations.
		tc.CurvePreferences = []tls.CurveID{tls.CurveP256, tls.X25519}

		if conf.Mode == kick.TLSModern {
			tc.MinVersion = tls.VersionTLS12
			tc.CipherSuites = []uint16{
				tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
				tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
				tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
				tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
				// Left out on purpose: no Forward Secrecy, though some
				// clients may still require them.
				// tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
				// tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
			}
		}

	default:
		return nil, fmt.Errorf("tls: unsupported TLS mode: %d", conf.Mode)
	}

	// Certificates are obtained on demand by the autocert manager, so
	// there is nothing to load from disk in that case.
	if conf.AutoCert {
		return tc, nil
	}
	if err := configureTLSCertificate(tc, conf.CertFile, conf.KeyFile); err != nil {
		return nil, err
	}
	return tc, nil
}
// configureTLSAutoCert returns the TLS configuration of an
// autocert.Manager built from config. The manager:
//   - stores certificates in config.AutoCertCacheDir;
//   - automatically accepts the CA's terms of service;
//   - only serves the host names listed in config.AutoCertExactHosts
//     (exact matches);
//   - registers config.AutoCertContactEmail as the account contact.
//
// NOTE(review): with an empty AutoCertExactHosts list the whitelist
// presumably rejects every host; confirm this is the intended default.
func configureTLSAutoCert(config *kick.TLSConfig) *tls.Config {
	manager := autocert.Manager{
		Prompt:     autocert.AcceptTOS,
		Email:      config.AutoCertContactEmail,
		HostPolicy: autocert.HostWhitelist(config.AutoCertExactHosts...),
		Cache:      autocert.DirCache(config.AutoCertCacheDir),
	}
	return manager.TLSConfig()
}
// configureTLSCertificate loads the PEM-encoded X.509 certificate and
// private key from certFile and keyFile and installs the pair as the only
// certificate of tc.
//
// Both file names are required. If either is empty, or if the pair cannot
// be loaded, an error is returned and tc is left unmodified. Load errors
// wrap the underlying error, so errors.Is (e.g. with fs.ErrNotExist) still
// works on the result.
func configureTLSCertificate(tc *tls.Config, certFile, keyFile string) error {
	if certFile == "" || keyFile == "" {
		return errors.New("tls: certificate and key files must be set")
	}
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return fmt.Errorf("tls: loading certificate %q and key %q: %w", certFile, keyFile, err)
	}
	// Only touch tc once the pair is known to be valid.
	tc.Certificates = []tls.Certificate{cert}
	return nil
}
// HTTPServer builds the *http.Server for s. It copies the listen address,
// the read/read-header/write/idle timeouts, the header size limit, the
// connection state hook and the error logger from s. No handler is set
// here. The returned error is always nil.
func HTTPServer(s *kick.Server) (*http.Server, error) {
	var srv http.Server

	srv.Addr = s.Addr

	srv.ReadTimeout = s.ReadTimeout
	srv.ReadHeaderTimeout = s.ReadHeaderTimeout
	srv.WriteTimeout = s.WriteTimeout
	srv.IdleTimeout = s.IdleTimeout

	srv.MaxHeaderBytes = s.MaxHeaderBytes
	srv.ConnState = s.ConnStateHook
	srv.ErrorLog = s.ErrorLog

	return &srv, nil
}
// middleware decorates an http.Handler, returning a handler that wraps it.
// combineMiddleware applies a slice of middleware so that the first element
// becomes the outermost wrapper.
type middleware func(http.Handler) http.Handler
// Handler builds the http.Handler that serves the routes of s.
//
// Routes are registered on an httprouter.Router with trailing-slash and
// fixed-path redirection enabled. When s.Root is set, its fallback handlers
// are installed on the router and its middleware chain wraps the whole
// router. The first route or root configuration error aborts the build.
func Handler(s *kick.Server) (http.Handler, error) {
	router := httprouter.New()
	router.RedirectTrailingSlash, router.RedirectFixedPath = true, true

	var chain []middleware
	if s.Root != nil {
		rootChain, err := rootMiddleware(s, router)
		if err != nil {
			return nil, err
		}
		chain = rootChain
	}

	for _, rt := range s.Routes {
		if err := configureRoute(router, rt); err != nil {
			return nil, err
		}
	}

	return combineMiddleware(router, chain), nil
}
// rootMiddleware applies the router-level fallbacks configured in s.Root
// (method-not-allowed and not-found handlers) and returns the server-wide
// middleware chain enabled by s.Root.
//
// The returned slice is ordered outermost first, which is the order
// combineMiddleware expects:
//   - panicRecovery
//   - proxy headers
//   - HTTP method override
//   - requestID
//   - security headers
//   - logging
//   - canonical host redirection
//   - gzip compression
//
// Entries whose setting is unset are omitted.
//
// It dereferences s.Root, so it must only be called when s.Root is non-nil.
// It returns an error if the canonical host redirect code is outside the
// 3xx range or if the gzip handler cannot be built from s.Root.Gzip.
func rootMiddleware(s *kick.Server, router *httprouter.Router) ([]middleware, error) {
	root := s.Root
	if root.MethodNotAllowedHandler != nil {
		router.MethodNotAllowed = root.MethodNotAllowedHandler
		router.HandleMethodNotAllowed = true
	}
	if root.NotFoundHandler != nil {
		router.NotFound = root.NotFoundHandler
	}

	var mw []middleware
	if fn := root.PanicRecoveryFunc; fn != nil {
		mw = append(mw, panicRecovery(fn))
	}
	if root.TrustProxyHeaders {
		mw = append(mw, handlers.ProxyHeaders)
	}
	if root.AllowMethodOverride {
		mw = append(mw, handlers.HTTPMethodOverrideHandler)
	}
	if root.RequestIDHeader != "" {
		mw = append(mw, requestID(root.RequestIDHeader, root.RequestIDForceNew))
	}
	if root.SecurityHeaders != nil {
		sec := secure.New(*root.SecurityHeaders)
		mw = append(mw, sec.Handler)
	}
	if fn := root.LoggingFunc; fn != nil {
		mw = append(mw, logging(fn))
	}
	if root.CanonicalHost != "" {
		redir := root.CanonicalHostRedirectStatusCode
		if redir < 300 || redir >= 400 {
			return nil, fmt.Errorf("handler: canonical host redirection code must be in the 3xx family: %d", redir)
		}
		mw = append(mw, handlers.CanonicalHost(root.CanonicalHost, redir))
	}
	if conf := root.Gzip; conf != nil {
		// A zero CompressionLevel is treated as "unset" and mapped to the
		// default level, so gzip.NoCompression (0) cannot be requested here.
		level := conf.CompressionLevel
		if level == 0 {
			level = gzip.DefaultCompression
		}
		gzh, err := gziphandler.GzipHandlerWithOpts(
			gziphandler.CompressionLevel(level),
			gziphandler.ContentTypes(conf.ContentTypes),
			gziphandler.MinSize(conf.MinSize),
		)
		if err != nil {
			// Prefix like every other configuration error of this package,
			// while keeping the cause available to errors.Is/As.
			return nil, fmt.Errorf("handler: gzip: %w", err)
		}
		mw = append(mw, gzh)
	}
	return mw, nil
}
// configureRoute validates route and registers it on router.
//
// A route must have a non-empty path starting with "/", a method and a
// handler. When route.Config is nil the handler is registered as is.
// Otherwise it is wrapped with the middleware enabled by route.Config, in
// this order (outermost first, as applied by combineMiddleware):
//   - timeoutHandler
//   - requestLimit
//   - CORS
//   - csrf.Protect
//   - limitRequestBodyBytes
//   - limitResponseBodyBytes
//   - requestContentType
//   - responseContentType
//
// An error is returned for an invalid route. An error is also returned for
// a RequestLimit configuration that does not set exactly one of
// FillInterval or Rate.
func configureRoute(router *httprouter.Router, route *kick.Route) error {
	if route.Path == "" {
		return errors.New("handler: route must have a path")
	}
	if route.Path[0] != '/' {
		return fmt.Errorf("handler: route path %s must start with a slash `/`", route.Path)
	}
	if route.Method == "" {
		return fmt.Errorf("handler: method is missing for route %s", route.Path)
	}
	if route.Handler == nil {
		return fmt.Errorf("handler: handler is missing for route %s %s", route.Method, route.Path)
	}
	if route.Config == nil {
		router.Handler(route.Method, route.Path, route.Handler)
		return nil
	}

	cfg := route.Config
	var mw []middleware
	if cfg.HandlerTimeout > 0 {
		mw = append(mw, timeoutHandler(cfg.HandlerTimeout))
	}
	if conf := cfg.RequestLimit; conf != nil {
		// Exactly one of FillInterval or Rate must be set.
		if conf.FillInterval > 0 && conf.Rate > 0 {
			return fmt.Errorf("handler: only one of FillInterval or Rate can be set for route %s %s", route.Method, route.Path)
		}
		if conf.FillInterval <= 0 && conf.Rate <= 0 {
			return fmt.Errorf("handler: one of FillInterval or Rate must be set for route %s %s", route.Method, route.Path)
		}
		mw = append(mw, requestLimit(conf))
	}
	if conf := cfg.CORS; conf != nil {
		opts := []handlers.CORSOption{
			handlers.AllowedHeaders(conf.AllowedHeaders),
			handlers.AllowedMethods(conf.AllowedMethods),
			handlers.AllowedOrigins(conf.AllowedOrigins),
			handlers.ExposedHeaders(conf.ExposedHeaders),
			// CORS max age is expressed in whole seconds.
			handlers.MaxAge(int(conf.MaxAge / time.Second)),
		}
		if conf.AllowCredentials {
			opts = append(opts, handlers.AllowCredentials())
		}
		if conf.OptionsStatusCode > 0 {
			opts = append(opts, handlers.OptionStatusCode(conf.OptionsStatusCode))
		}
		mw = append(mw, handlers.CORS(opts...))
	}
	if conf := cfg.CSRF; conf != nil {
		mw = append(mw, csrf.Protect(
			conf.AuthKey,
			csrf.CookieName(conf.CookieName),
			csrf.Domain(conf.Domain),
			csrf.HttpOnly(conf.HTTPOnly),
			// Cookie max age is expressed in whole seconds.
			csrf.MaxAge(int(conf.MaxAge/time.Second)),
			csrf.Path(conf.Path),
			csrf.RequestHeader(conf.RequestHeader),
			csrf.Secure(conf.Secure),
		))
	}
	if cfg.MaxRequestBodyBytes > 0 {
		mw = append(mw, limitRequestBodyBytes(cfg.MaxRequestBodyBytes))
	}
	if cfg.MaxResponseBodyBytes > 0 {
		mw = append(mw, limitResponseBodyBytes(cfg.MaxResponseBodyBytes))
	}
	if cts := cfg.RequestContentTypes; len(cts) > 0 {
		mw = append(mw, requestContentType(cts))
	}
	if cts := cfg.ResponseContentTypes; len(cts) > 0 {
		mw = append(mw, responseContentType(cts))
	}

	router.Handler(route.Method, route.Path, combineMiddleware(route.Handler, mw))
	return nil
}
// combineMiddleware wraps h with every middleware in mw and returns the
// result. mw[0] ends up as the outermost layer, so a request flows through
// mw[0], mw[1], ... and finally reaches h. With an empty mw, h is returned
// unchanged.
func combineMiddleware(h http.Handler, mw []middleware) http.Handler {
	wrapped := h
	// Wrap from the innermost (last) middleware outwards.
	for n := len(mw); n > 0; n-- {
		wrapped = mw[n-1](wrapped)
	}
	return wrapped
}