~emersion/soju

ref: d0cf1d2882cf193db0825671b3e5f3a4db018f07 soju/config/config.go -rw-r--r-- 2.3 KiB
d0cf1d28Simon Ser Add support for WebSocket connections 1 year, 8 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
package config

import (
	"bufio"
	"fmt"
	"io"
	"os"

	"github.com/google/shlex"
)

// TLS holds the certificate and private-key file paths set by the "tls"
// configuration directive.
type TLS struct {
	// CertPath is the first parameter of the "tls" directive.
	CertPath string
	// KeyPath is the second parameter of the "tls" directive.
	KeyPath string
}

// Server is the server configuration. Defaults provides its initial values,
// and Parse overrides them from configuration directives.
type Server struct {
	// Listen holds one URI per "listen" directive, in file order.
	Listen []string

	// Hostname is set by the "hostname" directive.
	Hostname string

	// TLS is set by the "tls" directive; it stays nil when no such
	// directive is present.
	TLS *TLS

	// SQLDriver and SQLSource are set together by the "sql" directive.
	SQLDriver string
	SQLSource string

	// LogPath is set by the "log" directive.
	LogPath string

	// HTTPOrigins accumulates the parameters of every "http-origin"
	// directive.
	HTTPOrigins []string
}

// Defaults returns a new Server holding the built-in default settings.
//
// Hostname is taken from os.Hostname, falling back to "localhost" when that
// call fails. The SQL driver is "sqlite3" with data source "soju.db". All
// other fields are left at their zero values.
func Defaults() *Server {
	cfg := &Server{
		SQLDriver: "sqlite3",
		SQLSource: "soju.db",
	}

	cfg.Hostname = "localhost"
	if name, err := os.Hostname(); err == nil {
		cfg.Hostname = name
	}

	return cfg
}

// Load opens the configuration file at path and parses it with Parse.
//
// An error from opening the file is returned unchanged. The file is closed
// before Load returns.
func Load(path string) (*Server, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()
	return Parse(file)
}

// Parse reads a configuration from r and returns the resulting server
// configuration, starting from the values returned by Defaults.
//
// Each line is split into words with github.com/google/shlex. Lines that
// yield no words are skipped. Otherwise the first word names the directive
// and the remaining words are its parameters. Directives are applied in
// file order:
//   - "hostname", "tls", "sql" and "log" overwrite earlier values;
//   - "listen" and "http-origin" append to the existing values.
//
// Errors report the 1-based line number on which they occurred. They wrap
// the underlying error with %w so callers can inspect it with errors.Is or
// errors.As.
func Parse(r io.Reader) (*Server, error) {
	srv := Defaults()

	scanner := bufio.NewScanner(r)
	lineno := 0
	for scanner.Scan() {
		lineno++

		words, err := shlex.Split(scanner.Text())
		if err != nil {
			return nil, fmt.Errorf("failed to parse config file: line %d: %w", lineno, err)
		}
		if len(words) == 0 {
			continue // nothing to apply on this line
		}

		d := directive{Name: words[0], Params: words[1:]}
		if err := applyDirective(srv, &d); err != nil {
			return nil, fmt.Errorf("line %d: %w", lineno, err)
		}
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	return srv, nil
}

// applyDirective updates srv according to a single parsed directive. It
// returns an error for unknown directive names and for parameter counts
// the directive does not accept.
func applyDirective(srv *Server, d *directive) error {
	switch d.Name {
	case "listen":
		var uri string
		if err := d.parseParams(&uri); err != nil {
			return err
		}
		srv.Listen = append(srv.Listen, uri)
	case "hostname":
		return d.parseParams(&srv.Hostname)
	case "tls":
		tls := &TLS{}
		if err := d.parseParams(&tls.CertPath, &tls.KeyPath); err != nil {
			return err
		}
		srv.TLS = tls
	case "sql":
		return d.parseParams(&srv.SQLDriver, &srv.SQLSource)
	case "log":
		return d.parseParams(&srv.LogPath)
	case "http-origin":
		// Every other directive validates its arity. A bare "http-origin"
		// would otherwise be accepted while doing nothing.
		if len(d.Params) == 0 {
			return fmt.Errorf("directive %q requires at least one parameter", d.Name)
		}
		srv.HTTPOrigins = append(srv.HTTPOrigins, d.Params...)
	default:
		return fmt.Errorf("unknown directive %q", d.Name)
	}
	return nil
}

// directive is one configuration line after it has been split into words.
type directive struct {
	// Name is the first word on the line.
	Name string
	// Params holds the remaining words, in order.
	Params []string
}

// parseParams stores the directive's parameters, in order, into the strings
// that out points to.
//
// If the number of parameters differs from len(out), it returns an error
// naming the directive and both counts, and leaves every target unmodified.
func (d *directive) parseParams(out ...*string) error {
	want, got := len(out), len(d.Params)
	if want != got {
		return fmt.Errorf("directive %q has wrong number of parameters: expected %v, got %v", d.Name, want, got)
	}
	for i, dst := range out {
		*dst = d.Params[i]
	}
	return nil
}