~erock/pico

ref: 90951e09568b61d17747a4a76036f2a876fe3f3e pico/shared/router.go -rw-r--r-- 2.9 KiB
90951e09Eric Bower fix(lists): wrong anchor links 5 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
package shared

import (
	"context"
	"fmt"
	"net"
	"net/http"
	"regexp"
	"strings"

	"git.sr.ht/~erock/pico/db"
	"go.uber.org/zap"
)

// Route binds an HTTP method and a compiled path pattern to the handler that
// serves requests matching both. Values are normally built with NewRoute.
type Route struct {
	// method is compared for equality against the incoming request's Method.
	method  string
	// regex is matched against the request URL path; its capture groups are
	// what GetField later returns.
	regex   *regexp.Regexp
	// handler is invoked once both the path and the method match.
	handler http.HandlerFunc
}

// NewRoute builds a Route that serves method requests whose URL path matches
// pattern in its entirety.
//
// pattern is a regular expression without anchors; NewRoute adds "^" and "$"
// itself. The pattern is wrapped in a non-capturing group before anchoring so
// that a top-level alternation such as "/a|/b" is anchored as a whole rather
// than as "^/a" or "/b$". The wrapper does not capture, so the numbering of
// the pattern's own capture groups (as seen by GetField) is unaffected.
//
// NewRoute panics if pattern is not a valid regular expression, so it is
// meant to be called while assembling the route table at startup.
func NewRoute(method, pattern string, handler http.HandlerFunc) Route {
	return Route{
		method:  method,
		regex:   regexp.MustCompile("^(?:" + pattern + ")$"),
		handler: handler,
	}
}

// ServeFn is the handler function type returned by CreateServe. It has the
// same signature as http.HandlerFunc.
type ServeFn func(http.ResponseWriter, *http.Request)

// CreateServe returns the top-level request handler for a site.
//
// For each request the returned function:
//
//   - chooses a route table: routes by default, or subdomainRoutes when the
//     request host resolves to a non-empty subdomain value (see below);
//   - runs the first route in that table whose regex matches r.URL.Path and
//     whose method equals r.Method, after storing logger, the subdomain,
//     dbpool, cfg and the regex capture groups in the request context (read
//     them back with GetLogger, GetSubdomain, GetDB, GetCfg and GetField);
//   - otherwise replies 405 with an Allow header when at least one route
//     matched the path but not the method, or 404 when no route matched.
//
// Host resolution only runs when cfg.IsCustomdomains() or cfg.IsSubdomains()
// reports true. Host and cfg.ConfigCms.Domain are compared lowercased and with
// any ":port" removed. A host is treated as a subdomain only if it ends with
// "." followed by the app domain; every other host that differs from the app
// domain is resolved through GetCustomDomain.
func CreateServe(routes []Route, subdomainRoutes []Route, cfg *ConfigSite, dbpool db.DB, logger *zap.SugaredLogger) ServeFn {
	return func(w http.ResponseWriter, r *http.Request) {
		var allow []string
		var subdomain string
		curRoutes := routes

		if cfg.IsCustomdomains() || cfg.IsSubdomains() {
			hostDomain := strings.ToLower(strings.Split(r.Host, ":")[0])
			appDomain := strings.ToLower(strings.Split(cfg.ConfigCms.Domain, ":")[0])

			if hostDomain != appDomain {
				// Require a real ".<appDomain>" suffix. A plain substring
				// check would also accept hosts such as
				// "<appDomain>.evil.example" or "not<appDomain>" and use the
				// whole host as the subdomain.
				appSuffix := "." + appDomain
				if strings.HasSuffix(hostDomain, appSuffix) {
					subdomain = strings.TrimSuffix(hostDomain, appSuffix)
				} else {
					subdomain = GetCustomDomain(hostDomain, cfg.Space)
				}
				if subdomain != "" {
					curRoutes = subdomainRoutes
				}
			}
		}

		for _, route := range curRoutes {
			matches := route.regex.FindStringSubmatch(r.URL.Path)
			if len(matches) == 0 {
				continue
			}
			if r.Method != route.method {
				// Path matched but method did not: remember it for the
				// Allow header and keep looking for a better route.
				allow = append(allow, route.method)
				continue
			}

			ctx := r.Context()
			ctx = context.WithValue(ctx, ctxLoggerKey{}, logger)
			ctx = context.WithValue(ctx, ctxSubdomainKey{}, subdomain)
			ctx = context.WithValue(ctx, ctxDBKey{}, dbpool)
			ctx = context.WithValue(ctx, ctxCfg{}, cfg)
			// matches[0] is the whole path; handlers only need the groups.
			ctx = context.WithValue(ctx, ctxKey{}, matches[1:])
			route.handler(w, r.WithContext(ctx))
			return
		}

		if len(allow) > 0 {
			w.Header().Set("Allow", strings.Join(allow, ", "))
			http.Error(w, "405 method not allowed", http.StatusMethodNotAllowed)
			return
		}
		http.NotFound(w, r)
	}
}

// ctxDBKey is the context key under which CreateServe stores the db.DB handle.
type ctxDBKey struct{}
// ctxKey is the context key under which CreateServe stores the route regex's
// capture groups as a []string.
type ctxKey struct{}
// ctxLoggerKey is the context key under which CreateServe stores the logger.
type ctxLoggerKey struct{}
// ctxSubdomainKey is the context key under which CreateServe stores the
// resolved subdomain string (empty when none was resolved).
type ctxSubdomainKey struct{}
// ctxCfg is the context key under which CreateServe stores the *ConfigSite.
type ctxCfg struct{}

// GetCfg returns the *ConfigSite stored in the request context under ctxCfg.
// It panics if the context holds no such value, i.e. when the request did not
// pass through the handler built by CreateServe.
func GetCfg(r *http.Request) *ConfigSite {
	stored := r.Context().Value(ctxCfg{})
	return stored.(*ConfigSite)
}

// GetLogger returns the logger stored in the request context under
// ctxLoggerKey. It panics if the context holds no such value, i.e. when the
// request did not pass through the handler built by CreateServe.
func GetLogger(r *http.Request) *zap.SugaredLogger {
	stored := r.Context().Value(ctxLoggerKey{})
	return stored.(*zap.SugaredLogger)
}

// GetDB returns the db.DB handle stored in the request context under
// ctxDBKey. It panics if the context holds no such value, i.e. when the
// request did not pass through the handler built by CreateServe.
func GetDB(r *http.Request) db.DB {
	stored := r.Context().Value(ctxDBKey{})
	return stored.(db.DB)
}

// GetField returns the index-th capture group of the route pattern that
// matched this request. Index 0 is the first capture group, because the
// whole-path match is dropped before the groups are stored in the context.
//
// It panics if the context holds no capture groups or if index is out of
// range for the matched pattern.
func GetField(r *http.Request, index int) string {
	return r.Context().Value(ctxKey{}).([]string)[index]
}

// GetSubdomain returns the subdomain string stored in the request context
// under ctxSubdomainKey; it is empty when CreateServe resolved none. It panics
// if the context holds no such value at all.
func GetSubdomain(r *http.Request) string {
	stored := r.Context().Value(ctxSubdomainKey{})
	return stored.(string)
}

// GetCustomDomain looks up the DNS TXT records of "_<space>.<host>" and
// returns the first record with surrounding whitespace trimmed.
//
// It returns the empty string when the lookup fails or yields no records;
// the lookup error itself is not reported to the caller.
func GetCustomDomain(host string, space string) string {
	name := fmt.Sprintf("_%s.%s", space, host)

	txt, lookupErr := net.LookupTXT(name)
	if lookupErr != nil || len(txt) == 0 {
		return ""
	}

	// Only the first TXT record is consulted.
	return strings.TrimSpace(txt[0])
}