~emersion/soju

ref: 4b3469335ed2bd4d43a7f3dc945e7220fd05bdb4 soju/config/config.go -rw-r--r-- 2.2 KiB
4b346933Simon Ser Fail auth on empty password in DB 1 year, 7 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package config

import (
	"bufio"
	"fmt"
	"io"
	"os"

	"github.com/google/shlex"
)

// TLS holds the two parameters of the "tls" configuration directive.
type TLS struct {
	// CertPath is the first parameter of the "tls" directive.
	CertPath string
	// KeyPath is the second parameter of the "tls" directive.
	KeyPath string
}

// Server is the configuration produced by Defaults and Parse.
type Server struct {
	// Listen collects the parameter of every "listen" directive, in file
	// order. It is nil when no "listen" directive is present.
	Listen    []string
	// Hostname is set by the "hostname" directive. Defaults fills it with
	// the OS hostname, or "localhost" if that cannot be determined.
	Hostname  string
	// TLS is set by the "tls" directive; it is nil when the directive is
	// absent.
	TLS       *TLS
	// SQLDriver is the first parameter of the "sql" directive; Defaults
	// sets it to "sqlite3".
	SQLDriver string
	// SQLSource is the second parameter of the "sql" directive; Defaults
	// sets it to "soju.db".
	SQLSource string
	// LogPath is set by the "log" directive; it is empty when the
	// directive is absent.
	LogPath   string
}

// Defaults returns a new Server populated with the default configuration:
// the "sqlite3" SQL driver, the "soju.db" SQL source, and the hostname
// reported by the operating system. If the OS hostname cannot be obtained,
// "localhost" is used instead. All other fields are left at their zero value.
func Defaults() *Server {
	srv := &Server{
		Hostname:  "localhost",
		SQLDriver: "sqlite3",
		SQLSource: "soju.db",
	}
	// Prefer the real hostname whenever the OS can provide one.
	if name, err := os.Hostname(); err == nil {
		srv.Hostname = name
	}
	return srv
}

// Load opens the file at path and parses its contents with Parse.
//
// An error from opening the file is returned unchanged; otherwise the result
// of Parse is returned. The file is closed before Load returns.
func Load(path string) (*Server, error) {
	file, openErr := os.Open(path)
	if openErr != nil {
		return nil, openErr
	}
	defer file.Close()

	srv, parseErr := Parse(file)
	return srv, parseErr
}

// Parse reads a configuration from r and returns the resulting Server.
//
// The input is read line by line and each line is split into words with
// shlex.Split. Lines that produce no words are skipped. The first word of a
// line is the directive name and the remaining words are its parameters.
//
// Parsing starts from Defaults and applies the directives in file order:
//
//	listen <uri>           appended to Server.Listen (may be repeated)
//	hostname <name>        sets Server.Hostname
//	tls <cert> <key>       sets Server.TLS
//	sql <driver> <source>  sets Server.SQLDriver and Server.SQLSource
//	log <path>             sets Server.LogPath
//
// A non-nil error is returned (with a nil *Server) if a line cannot be split,
// if reading from r fails, if a directive name is unknown, or if a directive
// has the wrong number of parameters. Split and read errors wrap the
// underlying error, so callers can inspect it with errors.Is or errors.As.
func Parse(r io.Reader) (*Server, error) {
	scanner := bufio.NewScanner(r)

	// Tokenize the whole input first, so that syntax and read errors are
	// reported before any directive is interpreted.
	var directives []directive
	for scanner.Scan() {
		words, err := shlex.Split(scanner.Text())
		if err != nil {
			return nil, fmt.Errorf("failed to parse config file: %w", err)
		}
		if len(words) == 0 {
			continue
		}
		directives = append(directives, directive{Name: words[0], Params: words[1:]})
	}
	if err := scanner.Err(); err != nil {
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}

	srv := Defaults()
	for _, d := range directives {
		switch d.Name {
		case "listen":
			// Unlike the other directives, "listen" accumulates values
			// instead of overwriting the previous one.
			var uri string
			if err := d.parseParams(&uri); err != nil {
				return nil, err
			}
			srv.Listen = append(srv.Listen, uri)
		case "hostname":
			if err := d.parseParams(&srv.Hostname); err != nil {
				return nil, err
			}
		case "tls":
			// Fill a fresh value and only publish it once the parameters
			// have been validated.
			tlsCfg := &TLS{}
			if err := d.parseParams(&tlsCfg.CertPath, &tlsCfg.KeyPath); err != nil {
				return nil, err
			}
			srv.TLS = tlsCfg
		case "sql":
			if err := d.parseParams(&srv.SQLDriver, &srv.SQLSource); err != nil {
				return nil, err
			}
		case "log":
			if err := d.parseParams(&srv.LogPath); err != nil {
				return nil, err
			}
		default:
			return nil, fmt.Errorf("unknown directive %q", d.Name)
		}
	}

	return srv, nil
}

// directive is a single tokenized configuration line as collected by Parse.
type directive struct {
	// Name is the first word of the line.
	Name   string
	// Params holds the remaining words of the line, in order.
	Params []string
}

// parseParams copies the directive's parameters, in order, into the strings
// pointed to by out.
//
// The number of parameters must equal the number of destinations; otherwise
// an error naming the directive and both counts is returned and no
// destination is written.
func (d *directive) parseParams(out ...*string) error {
	want, got := len(out), len(d.Params)
	if want != got {
		return fmt.Errorf("directive %q has wrong number of parameters: expected %v, got %v", d.Name, want, got)
	}
	for idx, dst := range out {
		*dst = d.Params[idx]
	}
	return nil
}