~gsthnz/satellite

c1d97fa0a57a0c174f8593165ca4087c6956c1d7 — Gustavo Heinz 3 years ago 1023641
Add certificate generation
2 files changed, 113 insertions(+), 6 deletions(-)

M main.go
M tls.go
M main.go => main.go +14 -5
@@ 13,21 13,30 @@ var rootDomain string
var rootPort int

func main() {
    var address = flag.String("addr", "0.0.0.0:1965", "Address to run server")
	var address = flag.String("addr", "0.0.0.0:1965", "Address to run server")
	var domain = flag.String("domain", "", "Domain to accept connections")
	var root = flag.String("root", "", "Root directory to serve files")
	var certPath = flag.String("cert", "", "Path to SSL certificate file")
	var certDays = flag.Int("days", 365, "SSL certificate validity in days, will be rotated when it expires")
	var keyPath = flag.String("key", "", "Path to SSL key file")

	flag.Parse()

	if *certPath == "" || *keyPath == "" || *root == "" || *domain == "" {
	if *certPath == "" || *keyPath == "" || *root == "" || *domain == "" || *certDays <= 0 {
		flag.PrintDefaults()
		os.Exit(1)
	}

	rootPath = *root
    rootDomain = *domain
	rootDomain = *domain

	if !checkCertValidity(*certPath, *keyPath) {
		fmt.Println("Certificate is no longer valid, creating a new one")
		err := generateX509KeyPair(*certPath, *keyPath, rootDomain, *certDays)
		if err != nil {
			fmt.Println("Error generating certificate", err)
		}
	}

	ln, err := listenTLSServer(
		*certPath,


@@ 39,9 48,9 @@ func main() {
	}
	defer ln.Close()

    rootPort = ln.Addr().(*net.TCPAddr).Port
	rootPort = ln.Addr().(*net.TCPAddr).Port

    fmt.Println("Accepting new connections on", ln.Addr())
	fmt.Println("Accepting new connections on", ln.Addr())
	for {
		conn, err := ln.Accept()
		if err != nil {

M tls.go => tls.go +99 -1
@@ 1,10 1,17 @@
package main

import (
	"crypto/ed25519"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"crypto/tls"
	"time"
)

func listenTLSServer(certPath, keyPath, addr string) (net.Listener, error) {


@@ 13,6 20,97 @@ func listenTLSServer(certPath, keyPath, addr string) (net.Listener, error) {
		fmt.Println("Error loading certificate", err)
		os.Exit(1)
	}

	c := &tls.Config{Certificates: []tls.Certificate{cert}}
	return tls.Listen("tcp", addr, c)
}

func checkCertValidity(certPath, keyPath string) bool {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		return false
	}
	x509cert, err := x509.ParseCertificate(cert.Certificate[0])
	if err != nil {
		fmt.Println("Error checking certificate validity:", err)
		return false
	} else if x509cert.NotAfter.Sub(time.Now()) < 0 {
		fmt.Println("Certificate expired in", x509cert.NotAfter)
		return false
	}
	return true
}

func generateX509KeyPair(certPath, keyPath, domain string, validFor int) error {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		fmt.Println("Error generating ed25519 key", err)
		return err
	}
	keyUsage := x509.KeyUsageDigitalSignature

	notBefore := time.Now()
	notAfter := notBefore.AddDate(0, 0, validFor)

	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		fmt.Println("Error generating serial number", err)
		return err
	}

	template := x509.Certificate{
		SerialNumber:          serialNumber,
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              keyUsage,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
	}
	template.Subject = pkix.Name{
		CommonName: domain,
	}

	template.DNSNames = append(template.DNSNames, "localhost")

	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, pub, priv)
	if err != nil {
		fmt.Println("Error generating certificate", err)
		return err
	}
	err = writeX509KeyPair(certPath, keyPath, derBytes, priv)
	if err != nil {
		fmt.Println("Error writing certificate", err)
		return err
	}
	return nil
}

func writeX509KeyPair(certPath, keyPath string, derBytes []byte, privateKey ed25519.PrivateKey) error {
	certOut, err := os.Create(certPath)
	if err != nil {
		fmt.Printf("Failed to open %s for writing: %v", certPath, err)
	}
	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		fmt.Printf("Failed to write data to %s: %v", certPath, err)
	}
	if err := certOut.Close(); err != nil {
		fmt.Printf("Error closing %s: %v", certPath, err)
	}
	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		fmt.Printf("Failed to open %s for writing: %v", keyPath, err)
		return nil
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		fmt.Printf("Unable to marshal private key: %v", err)
	}
	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		fmt.Printf("Failed to write data to %s: %v", keyPath, err)
	}
	if err := keyOut.Close(); err != nil {
		fmt.Printf("Error closing %s: %v", keyPath, err)
	}
	return nil
}