~emersion/soju

ref: 0f6bac30b894b4a562f72ccffdaa0411b6ac74f3 soju/config/config.go -rw-r--r-- 2.3 KiB
0f6bac30Hubert Hirtz Drop TAGMSG in detached channels 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
package config

import (
	"fmt"
	"net"
	"os"

	"git.sr.ht/~emersion/go-scfg"
)

// IPSet is a list of IP networks that can be queried for membership.
type IPSet []*net.IPNet

// Contains reports whether ip belongs to at least one network of the set.
// A nil or empty set contains no address.
func (set IPSet) Contains(ip net.IP) bool {
	found := false
	// Stop scanning as soon as one network matches.
	for i := 0; i < len(set) && !found; i++ {
		found = set[i].Contains(ip)
	}
	return found
}

// loopbackIPs holds the two loopback networks: the IPv4 range 127.0.0.0/8
// and the single IPv6 address ::1/128.
var loopbackIPs = IPSet{
	{IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)},
	{IP: net.IPv6loopback, Mask: net.CIDRMask(128, 128)},
}

// TLS carries the two parameters of the "tls" configuration directive.
type TLS struct {
	// CertPath is the first parameter of the directive.
	CertPath string
	// KeyPath is the second parameter of the directive.
	KeyPath string
}

// Server is the parsed server configuration. Defaults returns the initial
// values and parse overrides them from the configuration directives.
type Server struct {
	Listen         []string // one entry appended per "listen" directive
	Hostname       string   // "hostname" directive; defaults to os.Hostname(), or "localhost" if that fails
	TLS            *TLS     // set by the "tls" directive; nil when the directive is absent
	SQLDriver      string   // first parameter of "sql"; defaults to "sqlite3"
	SQLSource      string   // second parameter of "sql"; defaults to "soju.db"
	LogPath        string   // parameter of the "log" directive
	HTTPOrigins    []string // parameters of the last "http-origin" directive
	AcceptProxyIPs IPSet    // networks from the last "accept-proxy-ip" directive
}

// Defaults returns a Server populated with the built-in default values: the
// "sqlite3" driver with the "soju.db" source, and the machine's hostname as
// reported by os.Hostname, or "localhost" when that lookup fails.
func Defaults() *Server {
	srv := &Server{
		Hostname:  "localhost",
		SQLDriver: "sqlite3",
		SQLSource: "soju.db",
	}
	// Replace the fallback hostname only when the lookup succeeded.
	if name, err := os.Hostname(); err == nil {
		srv.Hostname = name
	}
	return srv
}

// Load loads the scfg file at path with scfg.Load and turns its directives
// into a Server. Errors from loading or from directive parsing are returned
// unchanged, together with a nil Server.
func Load(path string) (*Server, error) {
	directives, loadErr := scfg.Load(path)
	if loadErr != nil {
		return nil, loadErr
	}
	srv, parseErr := parse(directives)
	return srv, parseErr
}

// parse builds a Server from the directives in cfg, starting from Defaults().
//
// Recognized directives are "listen", "hostname", "tls", "sql", "log",
// "http-origin" and "accept-proxy-ip". Any other directive name, or a
// directive whose parameters cannot be parsed, aborts parsing and yields a
// nil Server together with the error.
func parse(cfg scfg.Block) (*Server, error) {
	conf := Defaults()
	for _, dir := range cfg {
		var err error
		switch dir.Name {
		case "listen":
			// Cumulative: each occurrence adds one more entry.
			var uri string
			if err = dir.ParseParams(&uri); err == nil {
				conf.Listen = append(conf.Listen, uri)
			}
		case "hostname":
			err = dir.ParseParams(&conf.Hostname)
		case "tls":
			pair := new(TLS)
			if err = dir.ParseParams(&pair.CertPath, &pair.KeyPath); err == nil {
				conf.TLS = pair
			}
		case "sql":
			err = dir.ParseParams(&conf.SQLDriver, &conf.SQLSource)
		case "log":
			err = dir.ParseParams(&conf.LogPath)
		case "http-origin":
			// Not cumulative: a later directive replaces the earlier list.
			conf.HTTPOrigins = dir.Params
		case "accept-proxy-ip":
			// Not cumulative either: the set is rebuilt from this directive.
			conf.AcceptProxyIPs, err = parseProxyIPs(dir.Name, dir.Params)
		default:
			err = fmt.Errorf("unknown directive %q", dir.Name)
		}
		if err != nil {
			return nil, err
		}
	}

	return conf, nil
}

// parseProxyIPs converts the parameters of the directive called name into an
// IPSet. The literal "localhost" expands to the loopback networks; every
// other parameter must be in CIDR notation. No parameters gives a nil set.
func parseProxyIPs(name string, params []string) (IPSet, error) {
	var set IPSet
	for _, param := range params {
		if param == "localhost" {
			set = append(set, loopbackIPs...)
			continue
		}
		_, ipNet, err := net.ParseCIDR(param)
		if err != nil {
			return nil, fmt.Errorf("directive %q: failed to parse CIDR: %v", name, err)
		}
		set = append(set, ipNet)
	}
	return set, nil
}