~gsthnz/satellite

ref: 08c1a5a59b68a33880f45218cdfad85b6e1b7c65 satellite/tls.go -rw-r--r-- 4.5 KiB
08c1a5a5Paper Set minimum TLS version to TLS 1.2 1 year, 6 months ago
                                                                                
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"os"
	"path"
	"time"
)

// getTLSConfig assembles the server's *tls.Config from the domains listed in
// config.Domain. For each domain, checkDomainCerts supplies the certificate
// and key paths (regenerating the pair when the existing one is unusable).
// The pair is then loaded and appended to Certificates, in configuration
// order. MinVersion is set to TLS 1.2.
//
// Configuration problems are fatal. An empty domain name, a failure to obtain
// the certificate paths, or a failure to load the pair each prints a message
// and terminates the process with exit status 1.
func getTLSConfig() *tls.Config {
	tlsConfig := &tls.Config{MinVersion: tls.VersionTLS12}

	for i := range config.Domain {
		name := config.Domain[i].Name
		if name == "" {
			fmt.Printf("Invalid domain name on domain %d\n", i)
			os.Exit(1)
		}

		certPath, keyPath, err := checkDomainCerts(name)
		if err != nil {
			fmt.Printf("Unable to generate tls config: %s\n", err)
			os.Exit(1)
		}

		// NOTE(review): the pair is re-read from disk here even though the
		// validity check performed by checkDomainCerts may already have parsed it.
		pair, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			fmt.Println("Error loading certificate", err)
			os.Exit(1)
		}
		fmt.Println("Loaded certificate", certPath)

		tlsConfig.Certificates = append(tlsConfig.Certificates, pair)
	}

	return tlsConfig
}

// checkDomainCerts returns the certificate and key paths for domainName,
// "<domainName>.crt" and "<domainName>.key" inside config.TLS.Directory.
//
// When checkCertValidity rejects the pair at those paths (missing, unreadable
// or expired), a replacement valid for 365 days is generated there via
// generateX509KeyPair. The two paths are returned in every case. A non-nil
// error means the replacement could not be generated.
func checkDomainCerts(domainName string) (string, string, error) {
	dir := config.TLS.Directory
	certPath, keyPath := path.Join(dir, domainName+".crt"), path.Join(dir, domainName+".key")

	if checkCertValidity(certPath, keyPath) {
		return certPath, keyPath, nil
	}

	fmt.Printf("Certificate for %s is no longer valid, creating a new one\n", domainName)
	if err := generateX509KeyPair(certPath, keyPath, domainName, 365); err != nil {
		fmt.Println("Error generating certificate", err)
		return certPath, keyPath, err
	}
	fmt.Printf("Generated certificate for %s\n", domainName)

	return certPath, keyPath, nil
}

// checkCertValidity reports whether the PEM certificate/key pair at certPath
// and keyPath can be loaded and its first (leaf) certificate has not passed
// its NotAfter time. NotBefore is not examined.
//
// When the pair is usable, updateCertReloadTime is called with the leaf's
// NotAfter. A parse failure of the leaf is printed; a load failure (e.g. the
// files do not exist yet) is silent.
func checkCertValidity(certPath, keyPath string) bool {
	pair, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		// Missing or unreadable files just mean "not valid"; checkDomainCerts
		// responds by generating a new pair.
		return false
	}

	leaf, err := x509.ParseCertificate(pair.Certificate[0])
	if err != nil {
		fmt.Println("Error checking certificate validity:", err)
		return false
	}

	if time.Now().After(leaf.NotAfter) {
		return false
	}

	updateCertReloadTime(leaf.NotAfter)
	return true
}

// updateCertReloadTime advances the package-level reloadTime to notAfter when
// notAfter is strictly later than the current value. It never moves
// reloadTime backwards, so after several calls it holds the latest expiry seen.
//
// NOTE(review): if reloadTime is used to schedule certificate regeneration,
// keeping the *latest* NotAfter across several domains could delay the reload
// past an earlier certificate's expiry — confirm against the code that reads it.
func updateCertReloadTime(notAfter time.Time) {
	if !notAfter.After(reloadTime) {
		return
	}
	reloadTime = notAfter
}

// generateX509KeyPair creates a fresh ed25519 key and a self-signed server
// certificate for domain, then persists both through writeX509KeyPair.
//
// The certificate uses domain as its CommonName and only DNS name. It has a
// random serial number below 2^128 and is valid from now for validFor days.
// It is restricted to digital signatures and server authentication, and
// carries basic constraints marking it as a non-CA certificate.
// On success, updateCertReloadTime is called with the new expiry. Every
// failure is printed and then returned unchanged.
func generateX509KeyPair(certPath, keyPath, domain string, validFor int) error {
	publicKey, privateKey, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		fmt.Println("Error generating ed25519 key", err)
		return err
	}

	validFrom := time.Now()
	validUntil := validFrom.AddDate(0, 0, validFor)

	maxSerial := new(big.Int).Lsh(big.NewInt(1), 128)
	serial, err := rand.Int(rand.Reader, maxSerial)
	if err != nil {
		fmt.Println("Error generating serial number", err)
		return err
	}

	tmpl := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: domain},
		DNSNames:              []string{domain},
		NotBefore:             validFrom,
		NotAfter:              validUntil,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}

	// Passing the template as its own parent makes the certificate self-signed.
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, publicKey, privateKey)
	if err != nil {
		fmt.Println("Error generating certificate", err)
		return err
	}

	if err := writeX509KeyPair(certPath, keyPath, der, privateKey); err != nil {
		fmt.Println("Error writing certificate", err)
		return err
	}

	updateCertReloadTime(validUntil)
	return nil
}

// writeX509KeyPair stores a certificate and its private key as PEM files.
//
// derBytes (the DER-encoded certificate) is written to certPath as a
// "CERTIFICATE" block. privateKey is marshalled to PKCS #8 and written to
// keyPath as a "PRIVATE KEY" block; a newly created key file gets mode 0600.
// The directory containing certPath is created with mode 0700 if needed.
// keyPath's directory is not created separately; callers in this file place
// both files in the same directory.
//
// The key is marshalled before either file is opened, so a marshalling
// failure leaves nothing truncated on disk. Every failure is printed and
// returned, and each file is closed on every path, including error paths.
func writeX509KeyPair(certPath, keyPath string, derBytes []byte, privateKey ed25519.PrivateKey) error {
	dir := path.Dir(certPath)
	if err := os.MkdirAll(dir, 0700); err != nil {
		fmt.Printf("Error creating directory %s: %v\n", dir, err)
		return err
	}

	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		fmt.Printf("Unable to marshal private key: %v\n", err)
		return err
	}

	// writePEM creates/truncates name with perm, encodes block into it and
	// closes it. The descriptor is released even when encoding fails.
	writePEM := func(name string, perm os.FileMode, block *pem.Block) error {
		out, err := os.OpenFile(name, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, perm)
		if err != nil {
			fmt.Printf("Failed to open %s for writing: %v\n", name, err)
			return err
		}
		if err := pem.Encode(out, block); err != nil {
			// Best-effort close: the encode error is the one worth reporting,
			// but the file must not leak.
			_ = out.Close()
			fmt.Printf("Failed to write data to %s: %v\n", name, err)
			return err
		}
		if err := out.Close(); err != nil {
			fmt.Printf("Error closing %s: %v\n", name, err)
			return err
		}
		return nil
	}

	// 0666 matches os.Create, which the certificate was previously written with.
	if err := writePEM(certPath, 0666, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		return err
	}
	return writePEM(keyPath, 0600, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
}