package main
import (
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"path"
	"time"
)
// getTLSConfig builds a *tls.Config that carries one certificate per entry of
// config.Domain, appended in the same order as the entries.
//
// For every domain, checkDomainCerts yields the certificate/key paths (and
// regenerates the pair first when it is not considered valid); the pair is
// then loaded with tls.LoadX509KeyPair. Any failure — an empty domain name,
// a generation error, or a load error — prints a message and terminates the
// process with exit status 1.
func getTLSConfig() *tls.Config {
	tlsConfig := &tls.Config{}
	for idx := range config.Domain {
		name := config.Domain[idx].Name
		if name == "" {
			fmt.Printf("Invalid domain name on domain %d\n", idx)
			os.Exit(1)
		}

		crtFile, keyFile, genErr := checkDomainCerts(name)
		if genErr != nil {
			fmt.Printf("Unable to generate tls config: %s\n", genErr)
			os.Exit(1)
		}

		pair, loadErr := tls.LoadX509KeyPair(crtFile, keyFile)
		if loadErr != nil {
			fmt.Println("Error loading certificate", loadErr)
			os.Exit(1)
		}
		fmt.Println("Loaded certificate", crtFile)

		tlsConfig.Certificates = append(tlsConfig.Certificates, pair)
	}
	return tlsConfig
}
// checkDomainCerts returns the certificate and key paths for domainName,
// located at <config.TLS.Directory>/<domainName>.crt and .key.
//
// If checkCertValidity rejects the pair stored there, a new self-signed pair
// valid for 365 days is generated at those paths before returning. When
// generation fails the paths are still returned, together with the error.
func checkDomainCerts(domainName string) (string, string, error) {
	tlsDir := config.TLS.Directory
	crtFile := path.Join(tlsDir, domainName+".crt")
	keyFile := path.Join(tlsDir, domainName+".key")

	// Fast path: the existing pair is usable as-is.
	if checkCertValidity(crtFile, keyFile) {
		return crtFile, keyFile, nil
	}

	fmt.Printf("Certificate for %s is no longer valid, creating a new one\n", domainName)
	if genErr := generateX509KeyPair(crtFile, keyFile, domainName, 365); genErr != nil {
		fmt.Println("Error generating certificate", genErr)
		return crtFile, keyFile, genErr
	}
	fmt.Printf("Generated certificate for %s\n", domainName)
	return crtFile, keyFile, nil
}
// checkCertValidity reports whether certPath/keyPath hold a loadable key pair
// whose leaf certificate has not yet passed its NotAfter time.
//
// Only expiry is examined: NotBefore and the names in the certificate are not
// checked. On success the leaf's NotAfter is passed to updateCertReloadTime.
func checkCertValidity(certPath, keyPath string) bool {
	pair, loadErr := tls.LoadX509KeyPair(certPath, keyPath)
	if loadErr != nil {
		// Missing, unreadable or mismatched files all count as "not valid".
		return false
	}

	leaf, parseErr := x509.ParseCertificate(pair.Certificate[0])
	if parseErr != nil {
		fmt.Println("Error checking certificate validity:", parseErr)
		return false
	}

	if time.Now().After(leaf.NotAfter) {
		return false
	}

	updateCertReloadTime(leaf.NotAfter)
	return true
}
// updateCertReloadTime moves the package-level reloadTime forward to notAfter
// when notAfter is strictly later; an earlier or equal notAfter leaves it
// untouched. reloadTime therefore tracks the latest expiry reported so far.
//
// NOTE(review): with several domains this keeps the *latest* NotAfter, not the
// earliest. If reloadTime is meant to trigger a reload before the first
// certificate expires, a minimum may be intended — confirm against the code
// that initialises and reads reloadTime.
func updateCertReloadTime(notAfter time.Time) {
	if !notAfter.After(reloadTime) {
		return
	}
	reloadTime = notAfter
}
// generateX509KeyPair creates a new Ed25519 key and a self-signed server-auth
// certificate for domain, valid from now for validFor days. It writes both as
// PEM to certPath/keyPath via writeX509KeyPair and then reports the new expiry
// to updateCertReloadTime.
//
// If domain parses as an IP literal it is stored as an IP SAN instead of a DNS
// SAN. TLS clients match an IP host only against the certificate's
// IPAddresses, so a DNS-only SAN would never verify for it.
//
// Errors are returned wrapped with the step that failed. Printing them is left
// to the caller.
func generateX509KeyPair(certPath, keyPath, domain string, validFor int) error {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		return fmt.Errorf("generating ed25519 key: %w", err)
	}

	notBefore := time.Now()
	notAfter := notBefore.AddDate(0, 0, validFor)

	// Random serial number in [0, 2^128).
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return fmt.Errorf("generating serial number: %w", err)
	}

	template := x509.Certificate{
		SerialNumber:          serialNumber,
		Subject:               pkix.Name{CommonName: domain},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}
	if ip := net.ParseIP(domain); ip != nil {
		template.IPAddresses = []net.IP{ip}
	} else {
		template.DNSNames = []string{domain}
	}

	// Self-signed: the template acts as both subject and issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
	if err != nil {
		return fmt.Errorf("creating certificate for %s: %w", domain, err)
	}

	if err := writeX509KeyPair(certPath, keyPath, derBytes, priv); err != nil {
		return fmt.Errorf("writing certificate for %s: %w", domain, err)
	}

	updateCertReloadTime(notAfter)
	return nil
}
// writeX509KeyPair writes derBytes as a PEM "CERTIFICATE" block to certPath and
// privateKey, marshalled as PKCS#8, as a PEM "PRIVATE KEY" block to keyPath.
// The parent directory of certPath is created with mode 0700 if needed. The
// certificate file is created with mode 0666 and the key file with 0600, both
// before umask.
//
// The key is marshalled before any file is touched, so a marshalling failure
// cannot leave a truncated key file behind. os.WriteFile closes the file on
// every path and reports close errors, so no descriptor can leak.
//
// NOTE(review): os.WriteFile does not change the permissions of a file that
// already exists. A pre-existing key file keeps whatever mode it had.
func writeX509KeyPair(certPath, keyPath string, derBytes []byte, privateKey ed25519.PrivateKey) error {
	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		return fmt.Errorf("marshalling private key: %w", err)
	}

	dir := path.Dir(certPath)
	if err := os.MkdirAll(dir, 0700); err != nil {
		return fmt.Errorf("creating certificate directory: %w", err)
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	if err := os.WriteFile(certPath, certPEM, 0666); err != nil {
		return fmt.Errorf("writing certificate: %w", err)
	}

	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
	if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
		return fmt.Errorf("writing private key: %w", err)
	}
	return nil
}