~sbaildon/blocky

ce585ea22b128a86e26183c1813d08b78592f123 — Sean Baildon 1 year, 4 months ago d7e9f88
more sane
4 files changed, 180 insertions(+), 73 deletions(-)

M server/server.go
A socket/launchd_disabled.go
A socket/launchd_enabled.go
A socket/net.go
M server/server.go => server/server.go +99 -73
@@ 1,13 1,8 @@
package server

/*
#include <stdlib.h>
int launch_activate_socket(const char *name, int **fds, size_t *cnt);
*/
import "C"

import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/rsa"
	"crypto/tls"


@@ 17,13 12,11 @@ import (
	"math/big"
	mrand "math/rand"
	"net"
	"os"
	"net/http"
	"runtime"
	"runtime/debug"
	"strings"
	"sync"
	"time"
	"unsafe"

	"github.com/0xERR0R/blocky/api"
	"github.com/0xERR0R/blocky/config"


@@ 32,6 25,7 @@ import (
	"github.com/0xERR0R/blocky/model"
	"github.com/0xERR0R/blocky/redis"
	"github.com/0xERR0R/blocky/resolver"
	"github.com/0xERR0R/blocky/socket"
	"github.com/0xERR0R/blocky/util"
	"github.com/hashicorp/go-multierror"



@@ 98,7 92,7 @@ func getServerAddress(addr string) string {
	return addr
}

type NewServerFunc func(address string) (*dns.Server, error)
type NewServerFunc func(address string, sockets map[string][]int) (*dns.Server, error)

func retrieveCertificate(cfg *config.Config) (cert tls.Certificate, err error) {
	if cfg.CertFile == "" && cfg.KeyFile == "" {


@@ 195,9 189,9 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err

	var err *multierror.Error

	addServers := func(newServer NewServerFunc, addresses config.ListenConfig) error {
	addServers := func(newServer NewServerFunc, sockets map[string][]int, addresses config.ListenConfig) error {
		for _, address := range addresses {
			server, err := newServer(getServerAddress(address))
			server, err := newServer(getServerAddress(address), sockets)
			if err != nil {
				return err
			}


@@ 208,11 202,21 @@ func createServers(cfg *config.Config, cert tls.Certificate) ([]*dns.Server, err
		return nil
	}

	err = multierror.Append(err, addServers(createUDPServer, cfg.DNSPorts))
	// addServers(createTCPServer, cfg.DNSPorts),
	// addServers(func(address string) (*dns.Server, error) {
	// 	return createTLSServer(address, cert)
	// }, cfg.TLSPorts))
	sockets := make(map[string][]int)
	for _, addr := range cfg.DNSPorts {
		if strings.HasPrefix(addr, "launchd:") {
			key := addr[8:]
			fds, _ := socket.LaunchdSockets(key)
			sockets[key] = fds
		}
	}

	err = multierror.Append(err,
		addServers(createUDPServer, sockets, cfg.DNSPorts),
		addServers(createTCPServer, sockets, cfg.DNSPorts),
		addServers(func(address string, sockets map[string][]int) (*dns.Server, error) {
			return createTLSServer(address, cert)
		}, sockets, cfg.TLSPorts))

	return dnsServers, err.ErrorOrNil()
}


@@ 275,34 279,67 @@ func createTLSServer(address string, cert tls.Certificate) (*dns.Server, error) 
	}, nil
}

func createTCPServer(address string) (*dns.Server, error) {
func createTCPServer(address string, sockets map[string][]int) (*dns.Server, error) {
	addr, network := socket.ParseAddress(address)

	var listener net.Listener

	logger().Infof("looking at TCP %s", network)

	switch network {
	case "launchd":
		logger().Infof("got launchd key: %s", addr)
		listener, _ = socket.BuildListener(sockets[addr][0])
	default:
		logger().Infof("got normie uri: %s", addr)
		var lc net.ListenConfig
		listener, _ = lc.Listen(context.Background(), "tcp", addr)
	}

	if listener == nil {
		logger().Info("nil listener")
	} else {
		logger().Info("good listener")
	}

	return &dns.Server{
		Addr:    address,
		Net:     "tcp",
		Handler: dns.NewServeMux(),
		Addr:     address,
		Net:      "tcp",
		Handler:  dns.NewServeMux(),
		Listener: listener,
		NotifyStartedFunc: func() {
			logger().Infof("TCP server is up and running on address %s", address)
		},
	}, nil
}

func createUDPServer(address string) (*dns.Server, error) {
	logger().Infof("Creating UDP server with %s\n", address)
	pktConn, _ := launchdSocket("Listeners")
func createUDPServer(address string, sockets map[string][]int) (*dns.Server, error) {
	addr, network := socket.ParseAddress(address)

	var pktConn net.PacketConn
	switch network {
	case "launchd":
		logger().Infof("got launchd key: %s", addr)
		pktConn, _ = socket.BuildPacketConn(sockets[addr][1])
	default:
		logger().Infof("got normie uri: %s", addr)
		var lc net.ListenConfig
		pktConn, _ = lc.ListenPacket(context.Background(), "udp", addr)
	}

	if pktConn == nil {
		logger().Info("got nil from launchd")
		logger().Info("nil pktcon")
	} else {
		logger().Info("good to go")
		logger().Info("good pktconn")
	}

	return &dns.Server{
		Addr:       "127.0.0.1:53",
		Addr:       addr,
		Net:        "udp",
		PacketConn: pktConn,
		Handler:    dns.NewServeMux(),
		NotifyStartedFunc: func() {
			logger().Infof("UDP server is up and running on address %s", address)
			logger().Infof("UDP server is up and running on address %s", addr)
		},
		UDPSize: maxUDPBufferSize,
	}, nil


@@ 485,63 522,52 @@ func toMB(b uint64) uint64 {
	return b / bytesInKB / bytesInKB
}

func unlockOnce(l sync.Locker) func() {
	var once sync.Once
	return func() { once.Do(l.Unlock) }
}

func launchdSocket(address string) (net.PacketConn, error) {
	c_name := C.CString(address)
	var c_fds *C.int
	c_cnt := C.size_t(0)
// Start starts the server
func (s *Server) Start(errCh chan<- error) {
	logger().Info("Starting server")

	err := C.launch_activate_socket(c_name, &c_fds, &c_cnt)
	if err != 0 {
		return nil, fmt.Errorf("couldn't activate launchd socket: %v", err)
	}
	for _, srv := range s.dnsServers {
		srv := srv

	length := int(c_cnt)
	if length != 1 {
		return nil, fmt.Errorf("expected exactly one socket to be configured in launchd for '%s', found %d", address, length)
		go func() {
			if err := srv.ActivateAndServe(); err != nil {
				errCh <- fmt.Errorf("start %s listener failed: %w", srv.Net, err)
			}
		}()
	}
	ptr := unsafe.Pointer(c_fds)
	defer C.free(ptr)

	fds := (*[1]C.int)(ptr)
	file := os.NewFile(uintptr(fds[0]), "")
	for i, listener := range s.httpListeners {
		listener := listener
		address := s.cfg.HTTPPorts[i]

	l, e := net.FilePacketConn(file)
		go func() {
			logger().Infof("http server is up and running on addr/port %s", address)

	if _, ok := l.(*net.UDPConn); !ok {
		logger().Info("typecast not okay")
	} else {
		logger().Info("typecast is okay?")
			if err := http.Serve(listener, s.httpMux); err != nil {
				errCh <- fmt.Errorf("start http listener failed: %w", err)
			}
		}()
	}

	return l, e
	// return net.FileListener(file)
}

// Start starts the server
func (s *Server) Start(errCh chan<- error) {
	logger().Info("Starting server")

	for i, _ := range s.dnsServers {
		srv := s.dnsServers[i]

		// if err != nil {
		// 	errCh <- fmt.Errorf("got err: %w", err)
		// }
	for i, listener := range s.httpsListeners {
		listener := listener
		address := s.cfg.HTTPSPorts[i]

		go func() {
			if srv.PacketConn == nil {
				logger().Info("packetconn is nill")
			} else {
				logger().Info("packetconn should be good")
			logger().Infof("https server is up and running on addr/port %s", address)

			server := http.Server{
				Handler: s.httpsMux,
				//nolint:gosec
				TLSConfig: &tls.Config{
					MinVersion:   minTLSVersion(),
					CipherSuites: tlsCipherSuites(),
					Certificates: []tls.Certificate{s.cert},
				},
			}

			if err := srv.ActivateAndServe(); err != nil {
				errCh <- fmt.Errorf("start %s listener failed: %w", srv.Net, err)
			if err := server.ServeTLS(listener, "", ""); err != nil {
				errCh <- fmt.Errorf("start https listener failed: %w", err)
			}
		}()
	}

A socket/launchd_disabled.go => socket/launchd_disabled.go +12 -0
@@ 0,0 1,12 @@
//go:build !darwin

package socket

import (
	"errors"
	"net"
)

func LaunchdSockets(address string) ([]int, error) {
	return nil, errors.New("launchd socket activation is only supported on darwin")
}

A socket/launchd_enabled.go => socket/launchd_enabled.go +53 -0
@@ 0,0 1,53 @@
//go:build darwin
package socket

/*
#include <stdlib.h>
#include <launch.h>
*/
import "C"

import (
	"fmt"
	"net"
	"os"
	"unsafe"
)

func LaunchdSockets(address string) ([]int, error) {
	c_name := C.CString(address)
	var c_fds *C.int
	c_cnt := C.size_t(0)

	err := C.launch_activate_socket(c_name, &c_fds, &c_cnt)
	if err != 0 {
		return nil, fmt.Errorf("couldn't activate launchd socket: %v", err)
	}

	length := int(c_cnt)
	if length < 2 {
		return nil, fmt.Errorf("expected at least two sockets to be configured in launchd for '%s', found %d", address, length)
	}
	ptr := unsafe.Pointer(c_fds)
	defer C.free(ptr)

	fds := (*[1 << 30]C.int)(ptr)[:length:length]
	res := make([]int, length)
	for i := 0; i < length; i++ {
		res[i] = int(fds[i])
	}

	return res, nil
}

func BuildPacketConn(fd int) (net.PacketConn, error) {
	file := os.NewFile(uintptr(fd), "")

	return net.FilePacketConn(file)
}

func BuildListener(fd int) (net.Listener, error) {
	file := os.NewFile(uintptr(fd), "")

	return net.FileListener(file)
}

A socket/net.go => socket/net.go +16 -0
@@ 0,0 1,16 @@
package socket

import "strings"

func ParseAddress(input string) (address, network string) {
	if strings.HasPrefix(input, "launchd:") {
		address = input[8:]
		network = "launchd"
		return
	}

	address = input
	network = ""

	return
}