~emersion/soju

ref: db55076d04033830fc52c2f910bb514337af049f soju/config/config.go -rw-r--r-- 2.7 KiB
db55076dSimon Ser wip: systemd-wide config/data paths 2 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
package config

import (
	"errors"
	"fmt"
	"net"
	"os"
	"path/filepath"

	"git.sr.ht/~emersion/go-scfg"
)

// Build-time locations. Both are empty unless a build script overrides them
// (presumably via `go build -ldflags "-X ..."` — confirm against the build
// scripts). When set, Load looks for its default config file under
// sysConfDir, and Defaults stores the database under sharedStateDir.
var (
	sysConfDir     string
	sharedStateDir string
)

// IPSet is a list of IP networks, such as the ones configured through the
// "accept-proxy-ip" directive.
type IPSet []*net.IPNet

// Contains reports whether ip lies within at least one network of the set.
// A nil or empty set contains no address.
func (set IPSet) Contains(ip net.IP) bool {
	matched := false
	for i := 0; i < len(set) && !matched; i++ {
		matched = set[i].Contains(ip)
	}
	return matched
}

// loopbackIPs covers the IPv4 loopback network 127.0.0.0/8 and the IPv6
// loopback address ::1/128. parse substitutes it for the "localhost" keyword.
var loopbackIPs = IPSet{
	{IP: net.IPv4(127, 0, 0, 0).To4(), Mask: net.CIDRMask(8, 8*net.IPv4len)},
	{IP: net.IPv6loopback, Mask: net.CIDRMask(8*net.IPv6len, 8*net.IPv6len)},
}

// TLS holds the two file paths given to the "tls" directive: the certificate
// first, then the private key.
type TLS struct {
	CertPath string
	KeyPath  string
}

// Server is the parsed soju server configuration. Defaults fills in the
// baseline values; parse then overrides them from config file directives.
type Server struct {
	// Listen lists the URIs from every "listen" directive, in file order.
	Listen         []string
	// Hostname defaults to os.Hostname (or "localhost" if that fails) and is
	// overridden by the "hostname" directive.
	Hostname       string
	// TLS is nil unless a "tls" directive is present.
	TLS            *TLS
	// SQLDriver and SQLSource come from the "sql" directive. They default to
	// "sqlite3" and either "soju.db" or <sharedStateDir>/soju/main.db.
	SQLDriver      string
	SQLSource      string
	// LogPath is set by the "log" directive; empty by default.
	LogPath        string
	// HTTPOrigins holds the parameters of the last "http-origin" directive.
	HTTPOrigins    []string
	// AcceptProxyIPs holds the networks from the last "accept-proxy-ip"
	// directive. The keyword "localhost" expands to the loopback networks.
	AcceptProxyIPs IPSet
}

// Defaults returns the baseline configuration. Load returns it when no config
// file is read, and parse applies directives on top of it.
//
// The hostname is taken from os.Hostname, or "localhost" if that call fails.
// The SQL driver is "sqlite3". The database source is "soju.db", or
// <sharedStateDir>/soju/main.db when sharedStateDir was set at build time.
func Defaults() *Server {
	srv := &Server{
		Hostname:  "localhost",
		SQLDriver: "sqlite3",
		SQLSource: "soju.db",
	}
	if name, err := os.Hostname(); err == nil {
		srv.Hostname = name
	}
	if sharedStateDir != "" {
		srv.SQLSource = filepath.Join(sharedStateDir, "soju", "main.db")
	}
	return srv
}

// Load reads the scfg configuration file at path and parses it on top of
// Defaults.
//
// If path is empty, the default location <sysConfDir>/soju/config is used.
// Defaults is returned without error in two cases:
//   - path is empty and sysConfDir was not set at build time;
//   - path is empty and the default file does not exist.
//
// A missing file at an explicitly provided path is reported as an error.
func Load(path string) (*Server, error) {
	explicit := path != ""
	if !explicit {
		if sysConfDir == "" {
			return Defaults(), nil
		}
		path = filepath.Join(sysConfDir, "soju", "config")
	}

	block, err := scfg.Load(path)
	switch {
	case err == nil:
		return parse(block)
	case !explicit && errors.Is(err, os.ErrNotExist):
		// The default config file is optional.
		return Defaults(), nil
	default:
		return nil, err
	}
}

// parse builds a Server from the top-level directives of a config file. It
// starts from Defaults and applies directives in file order.
//
// "listen" accumulates across occurrences. Every other directive overwrites
// the field(s) it controls, so the last occurrence wins. Errors from
// Directive.ParseParams are returned unchanged. An unknown directive or an
// unparsable "accept-proxy-ip" network aborts parsing with an error.
func parse(cfg scfg.Block) (*Server, error) {
	srv := Defaults()
	for _, d := range cfg {
		switch d.Name {
		case "listen":
			var uri string
			if err := d.ParseParams(&uri); err != nil {
				return nil, err
			}
			srv.Listen = append(srv.Listen, uri)
		case "hostname":
			if err := d.ParseParams(&srv.Hostname); err != nil {
				return nil, err
			}
		case "tls":
			tls := &TLS{}
			if err := d.ParseParams(&tls.CertPath, &tls.KeyPath); err != nil {
				return nil, err
			}
			srv.TLS = tls
		case "sql":
			if err := d.ParseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
				return nil, err
			}
		case "log":
			if err := d.ParseParams(&srv.LogPath); err != nil {
				return nil, err
			}
		case "http-origin":
			srv.HTTPOrigins = d.Params
		case "accept-proxy-ip":
			// Each occurrence replaces the set built by any earlier one.
			srv.AcceptProxyIPs = nil
			for _, s := range d.Params {
				// "localhost" is shorthand for both loopback networks.
				if s == "localhost" {
					srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, loopbackIPs...)
					continue
				}
				_, n, err := net.ParseCIDR(s)
				if err != nil {
					// Wrap with %w so callers can still reach the underlying
					// *net.ParseError via errors.As. The message text is unchanged.
					return nil, fmt.Errorf("directive %q: failed to parse CIDR: %w", d.Name, err)
				}
				srv.AcceptProxyIPs = append(srv.AcceptProxyIPs, n)
			}
		default:
			return nil, fmt.Errorf("unknown directive %q", d.Name)
		}
	}

	return srv, nil
}