~gsthnz/satellite

16b300ecf7996fcebe1fb3ba31ea6c22c2c85190 — Gustavo Heinz 3 years ago 8403b30
Support multiple hosts and regenerate certificates
9 files changed, 282 insertions(+), 112 deletions(-)

M README.md
A config.go
M gemini.go
A go.mod
A go.sum
M main.go
A server.go
M tls.go
A utils.go
M README.md => README.md +25 -5
@@ 4,14 4,34 @@ Satellite is a small Gemini server for serving static files.

## Usage
```
satellite -cert <certfile> -key <keyfile> -root <contentdir> -domain <domain>
satellite
```

The SSL certificate generation will be done automatically in the first run.
## Configuration
By default, Satellite uses `/etc/satellite.toml`, but a custom config file can
be supplied with the `-c` flag.

Every time the server starts, it will check for the certificate expire date
will generate a new one if needed. This is not optimal, I plan to do the
certificate rotation without needing to restart in the future.
### Example configuration
```
# Address to listen to requests (default: 0.0.0.0:1965)
listen = "0.0.0.0"

[tls]
# Directory to save certificates
directory = "/var/lib/gemini/certs"

# Multiple domains can be set with the [[domain]] section
[[domain]]
name = "example.com"
root = "/srv/gemini/example.com"

[[domain]]
name = "example2.com"
root = "/srv/gemini/example2.com"
```

The SSL certificate generation will be managed by Satellite and will be
automatically regenerated on the certificate expiry date.

## Building
```

A config.go => config.go +40 -0
@@ 0,0 1,40 @@
package main

import (
	"github.com/pelletier/go-toml"
)

type Route struct {
	Match string
	Root  string
}

type Domain struct {
	Name  string
	Root  string
	Route []Route
}

type TLS struct {
	Directory string
}

type Config struct {
	Listen  string
	CertDir string
	TLS     TLS
	Domain  []Domain
}

func LoadConfig(file string) (*Config, error) {
	c := new(Config)
	tree, err := toml.LoadFile(file)
	if err != nil {
		return c, err
	}
	err = tree.Unmarshal(c)
	if err != nil {
		return c, err
	}
	return c, nil
}

M gemini.go => gemini.go +8 -35
@@ 8,8 8,6 @@ import (
	"net/url"
	"os"
	"path"
	"strings"
	"unicode/utf8"
)

type Status int


@@ 32,35 30,10 @@ const MaxURLSize = 1024
const IndexFile = "index.gmi"
const GeminiMIME = "text/gemini"

func handleRequest(c net.Conn, rawURL string) {
	rawURL = strings.TrimSpace(rawURL)
	if rawURL == "" {
		sendError(c, BadRequest, "Empty URL")
		return
	} else if !utf8.ValidString(rawURL) {
		sendError(c, BadRequest, "Non UTF-8 Request")
		return
	} else if len(rawURL) > MaxURLSize {
		sendError(c, BadRequest, "URL Larger than 1024 bytes")
		return
	}

	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		sendError(c, BadRequest, "Bad URL")
	}

	if parsedURL.Scheme != "" && parsedURL.Scheme != "gemini" {
		sendError(c, ProxyRequestRefused, fmt.Sprintf("Unknown scheme '%s'", parsedURL.Scheme))
		return
	} else if parsedURL.Host == "" {
		sendError(c, BadRequest, "Host not supplied")
		return
	} else if parsedURL.Host != rootDomain && parsedURL.Host != fmt.Sprintf("%s:%d", rootDomain, rootPort) {
		sendError(c, ProxyRequestRefused, fmt.Sprintf("Host not found '%s'", parsedURL.Host))
		return
	} else if parsedURL.Path == "" {
		redirectPermanent(c, rawURL+"/")
func handleRequest(c net.Conn, di int, parsedURL *url.URL) {
	if parsedURL.Path == "" {
		parsedURL.Path = "/"
		redirectPermanent(c, parsedURL.String())
		return
	} else if parsedURL.Path != path.Clean(parsedURL.Path) {
		sendError(c, BadRequest, "Path error")


@@ 68,14 41,14 @@ func handleRequest(c net.Conn, rawURL string) {
	}

	if parsedURL.Path == "/" || parsedURL.Path == "." {
		serve(c, IndexFile)
		serve(c, di, IndexFile)
	} else {
		serve(c, parsedURL.Path)
		serve(c, di, parsedURL.Path)
	}
}

func serve(c net.Conn, filepath string) {
	fullPath := path.Join(rootPath, filepath)
func serve(c net.Conn, di int, filepath string) {
	fullPath := path.Join(config.Domain[di].Root, filepath)

	_, err := os.Stat(fullPath)
	if err != nil {

A go.mod => go.mod +5 -0
@@ 0,0 1,5 @@
module git.sr.ht/~gsthnz/satellite

go 1.15

require github.com/pelletier/go-toml v1.8.1

A go.sum => go.sum +4 -0
@@ 0,0 1,4 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pelletier/go-toml v1.8.1 h1:1Nf83orprkJyknT6h7zbuEGUEjcyVlCxSUGTENmNCRM=
github.com/pelletier/go-toml v1.8.1/go.mod h1:T2/BmBdy8dvIRq1a/8aqjN41wvWlN4lrapLU/GW4pbc=

M main.go => main.go +13 -54
@@ 1,72 1,31 @@
package main

import (
	"bufio"
	"flag"
	"fmt"
	"net"
	"os"
)

var rootPath string
var rootDomain string
var rootPort int
var config *Config

func main() {
	var address = flag.String("addr", "0.0.0.0:1965", "Address to run server")
	var domain = flag.String("domain", "", "Domain to accept connections")
	var root = flag.String("root", "", "Root directory to serve files")
	var certPath = flag.String("cert", "", "Path to SSL certificate file")
	var certDays = flag.Int("days", 365, "SSL certificate validity in days, will be rotated when it expires")
	var keyPath = flag.String("key", "", "Path to SSL key file")

	var err error
	var configFile = flag.String("c", "/etc/satellite.toml", "Config file")
	flag.Parse()

	if *certPath == "" || *keyPath == "" || *root == "" || *domain == "" || *certDays <= 0 {
		flag.PrintDefaults()
	config, err = LoadConfig(*configFile)
	if err != nil {
		fmt.Printf("Error loading configuration: %+v\n", err)
		os.Exit(1)
	}

	rootPath = *root
	rootDomain = *domain

	if !checkCertValidity(*certPath, *keyPath) {
		fmt.Println("Certificate is no longer valid, creating a new one")
		err := generateX509KeyPair(*certPath, *keyPath, rootDomain, *certDays)
		if err != nil {
			fmt.Println("Error generating certificate", err)
		}
	}

	ln, err := listenTLSServer(
		*certPath,
		*keyPath,
		*address)
	if err != nil {
		fmt.Println("Error listening", err)
		return
	if config.TLS.Directory == "" {
		fmt.Println("TLS certificate directory not set")
		os.Exit(1)
	}
	defer ln.Close()

	rootPort = ln.Addr().(*net.TCPAddr).Port

	fmt.Println("Accepting new connections on", ln.Addr())
	for {
		conn, err := ln.Accept()
		if err != nil {
			fmt.Println("Error accepting connections", err)
			return
		}
		go handleConnection(conn)
	if len(config.Domain) == 0 {
		fmt.Printf("No domains defined on %s\n", *configFile)
		os.Exit(1)
	}
}

func handleConnection(c net.Conn) {
	defer c.Close()

	req, err := bufio.NewReader(c).ReadString('\r')
	if err != nil {
		return
	}
	handleRequest(c, req)
	startServer()
}

A server.go => server.go +100 -0
@@ 0,0 1,100 @@
package main

import (
	"bufio"
	"crypto/tls"
	"fmt"
	"net"
	"net/url"
	"strings"
	"time"
	"unicode/utf8"
)

var reloadTime time.Time

func startServer() {
	for {
		ln, err := listenTLSServer(config.Listen)
		if err != nil {
			fmt.Println("Error starting server", err)
			return
		}
		defer ln.Close()

		go waitForCertReload(ln)

		fmt.Println("Accepting new connections on", ln.Addr())
		for {
			conn, err := ln.Accept()
			if err != nil {
				fmt.Println("Error accepting connections", err)
				break
			}
			go handleConnection(conn)
		}
	}
}

func waitForCertReload(ln net.Listener) {
	fmt.Println("Will reload certs on", reloadTime)
	<-time.After(time.Until(reloadTime))
	fmt.Println("Certificates expired, reloading server")
	ln.Close()
}

func listenTLSServer(addr string) (net.Listener, error) {
	cfg := getTLSConfig()
	addr = parseAddr(addr)
	return tls.Listen("tcp", addr, cfg)
}

func handleConnection(c net.Conn) {
	defer c.Close()

	req, err := bufio.NewReader(c).ReadString('\r')
	if err != nil {
		return
	}

	rawURL := strings.TrimSpace(req)
	if rawURL == "" {
		sendError(c, BadRequest, "Empty URL")
		return
	} else if !utf8.ValidString(rawURL) {
		sendError(c, BadRequest, "Non UTF-8 Request")
		return
	} else if len(rawURL) > MaxURLSize {
		sendError(c, BadRequest, "URL Larger than 1024 bytes")
		return
	}

	parsedURL, err := url.Parse(rawURL)
	if err != nil {
		sendError(c, BadRequest, "Bad URL")
	}
	if parsedURL.Scheme == "" {
		parsedURL.Scheme = "gemini"
	}
	if parsedURL.Scheme != "gemini" {
		sendError(c, ProxyRequestRefused, fmt.Sprintf("Unknown scheme '%s'", parsedURL.Scheme))
		return
	} else if parsedURL.Host == "" {
		sendError(c, BadRequest, "Host not supplied")
		return
	}
	_, port, _ := net.SplitHostPort(c.LocalAddr().String())
	if parsedURL.Port() != "" && parsedURL.Port() != port {
		sendError(c, ProxyRequestRefused, "Wrong port")
		return
	}

	for i := 0; i < len(config.Domain); i++ {
		if config.Domain[i].Name == parsedURL.Hostname() {
			handleRequest(c, i, parsedURL)
			return
		}
	}
	sendError(c, ProxyRequestRefused, fmt.Sprintf("Host not found '%s'",
		parsedURL.Host))
}

M tls.go => tls.go +67 -18
@@ 9,20 9,50 @@ import (
	"encoding/pem"
	"fmt"
	"math/big"
	"net"
	"os"
	"path"
	"time"
)

func listenTLSServer(certPath, keyPath, addr string) (net.Listener, error) {
	cert, err := tls.LoadX509KeyPair(certPath, keyPath)
	if err != nil {
		fmt.Println("Error loading certificate", err)
		os.Exit(1)
	}
func getTLSConfig() *tls.Config {
	cfg := &tls.Config{}
	for i := 0; i < len(config.Domain); i++ {
		if config.Domain[i].Name == "" {
			fmt.Printf("Invalid domain name on domain %d\n", i)
			os.Exit(1)
		}
		certPath, keyPath, err := checkDomainCerts(config.Domain[i].Name)
		if err != nil {
			fmt.Printf("Unable to generate tls config: %s\n", err)
			os.Exit(1)
		}
		cert, err := tls.LoadX509KeyPair(certPath, keyPath)
		if err != nil {
			fmt.Println("Error loading certificate", err)
			os.Exit(1)
		}
		fmt.Println("Loaded certificate", certPath)
		cfg.Certificates = append(cfg.Certificates, cert)
	}
	return cfg
}

func checkDomainCerts(domainName string) (string, string, error) {
	certPath := path.Join(config.TLS.Directory, domainName+".crt")
	keyPath := path.Join(config.TLS.Directory, domainName+".key")

	c := &tls.Config{Certificates: []tls.Certificate{cert}}
	return tls.Listen("tcp", addr, c)
	if !checkCertValidity(certPath, keyPath) {
		fmt.Printf(
			"Certificate for %s is no longer valid, creating a new one\n",
			domainName)
		err := generateX509KeyPair(certPath, keyPath, domainName, 365)
		if err != nil {
			fmt.Println("Error generating certificate", err)
			return certPath, keyPath, err
		}
		fmt.Printf("Generated certificate for %s\n", domainName)
	}
	return certPath, keyPath, nil
}

func checkCertValidity(certPath, keyPath string) bool {


@@ 35,12 65,18 @@ func checkCertValidity(certPath, keyPath string) bool {
		fmt.Println("Error checking certificate validity:", err)
		return false
	} else if x509cert.NotAfter.Sub(time.Now()) < 0 {
		fmt.Println("Certificate expired in", x509cert.NotAfter)
		return false
	}
	updateCertReloadTime(x509cert.NotAfter)
	return true
}

func updateCertReloadTime(notAfter time.Time) {
	if notAfter.Sub(reloadTime) > 0 {
		reloadTime = notAfter
	}
}

func generateX509KeyPair(certPath, keyPath, domain string, validFor int) error {
	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {


@@ 83,34 119,47 @@ func generateX509KeyPair(certPath, keyPath, domain string, validFor int) error {
		fmt.Println("Error writing certificate", err)
		return err
	}
	updateCertReloadTime(notAfter)
	return nil
}

func writeX509KeyPair(certPath, keyPath string, derBytes []byte, privateKey ed25519.PrivateKey) error {
	dir := path.Dir(certPath)
	err := os.MkdirAll(dir, 0700)
	if err != nil {
		fmt.Printf("Error creating directory %s: %v\n", certPath, err)
		return err
	}
	certOut, err := os.Create(certPath)
	if err != nil {
		fmt.Printf("Failed to open %s for writing: %v", certPath, err)
		fmt.Printf("Failed to open %s for writing: %v\n", certPath, err)
		return err
	}
	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
		fmt.Printf("Failed to write data to %s: %v", certPath, err)
		fmt.Printf("Failed to write data to %s: %v\n", certPath, err)
		return err
	}
	if err := certOut.Close(); err != nil {
		fmt.Printf("Error closing %s: %v", certPath, err)
		fmt.Printf("Error closing %s: %v\n", certPath, err)
		return err
	}
	keyOut, err := os.OpenFile(keyPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		fmt.Printf("Failed to open %s for writing: %v", keyPath, err)
		return nil
		fmt.Printf("Failed to open %s for writing: %v\n", keyPath, err)
		return err
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(privateKey)
	if err != nil {
		fmt.Printf("Unable to marshal private key: %v", err)
		fmt.Printf("Unable to marshal private key: %v\n", err)
		return err
	}
	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
		fmt.Printf("Failed to write data to %s: %v", keyPath, err)
		fmt.Printf("Failed to write data to %s: %v\n", keyPath, err)
		return err
	}
	if err := keyOut.Close(); err != nil {
		fmt.Printf("Error closing %s: %v", keyPath, err)
		fmt.Printf("Error closing %s: %v\n", keyPath, err)
		return err
	}
	return nil
}

A utils.go => utils.go +20 -0
@@ 0,0 1,20 @@
package main

import "net"

func parseAddr(addr string) string {
	hostname, port, err := net.SplitHostPort(addr)
	if err != nil {
		ip := net.ParseIP(addr)
		if ip != nil {
			hostname = ip.String()
		}
	}
	if hostname == "" {
		hostname = "0.0.0.0"
	}
	if port == "" {
		port = "1965"
	}
	return net.JoinHostPort(hostname, port)
}