~emersion/soju

d0cf1d2882cf193db0825671b3e5f3a4db018f07 — Simon Ser 4 years ago 4b34693
Add support for WebSocket connections

WebSocket connections allow web-based clients to connect to IRC. This
commit implements the WebSocket sub-protocol as specified by the pending
IRCv3 proposal [1].

WebSocket listeners can now be set up via a "wss" protocol in the
`listen` directive. The new `http-origin` directive allows the CORS
allowed origins to be configured.

[1]: https://github.com/ircv3/ircv3-specifications/pull/342
9 files changed, 155 insertions(+), 29 deletions(-)

M cmd/soju/main.go
M config/config.go
M conn.go
M doc/soju.1.scd
M downstream.go
M go.mod
M go.sum
M server.go
M upstream.go
M cmd/soju/main.go => cmd/soju/main.go +27 -0
@@ 5,6 5,7 @@ import (
	"flag"
	"log"
	"net"
	"net/http"
	"net/url"
	"strings"



@@ 56,6 57,7 @@ func main() {
	// TODO: load from config/DB
	srv.Hostname = cfg.Hostname
	srv.LogPath = cfg.LogPath
	srv.HTTPOrigins = cfg.HTTPOrigins
	srv.Debug = debug

	for _, listen := range cfg.Listen {


@@ 97,6 99,31 @@ func main() {
			go func() {
				log.Fatal(srv.Serve(ln))
			}()
		case "wss":
			addr := u.Host
			if _, _, err := net.SplitHostPort(addr); err != nil {
				addr = addr + ":https"
			}
			httpSrv := http.Server{
				Addr:      addr,
				TLSConfig: tlsCfg,
				Handler:   srv,
			}
			go func() {
				log.Fatal(httpSrv.ListenAndServeTLS("", ""))
			}()
		case "ws+insecure":
			addr := u.Host
			if _, _, err := net.SplitHostPort(addr); err != nil {
				addr = addr + ":http"
			}
			httpSrv := http.Server{
				Addr:    addr,
				Handler: srv,
			}
			go func() {
				log.Fatal(httpSrv.ListenAndServe())
			}()
		default:
			log.Fatalf("failed to listen on %q: unsupported scheme", listen)
		}

M config/config.go => config/config.go +9 -6
@@ 14,12 14,13 @@ type TLS struct {
}

type Server struct {
	Listen    []string
	Hostname  string
	TLS       *TLS
	SQLDriver string
	SQLSource string
	LogPath   string
	Listen      []string
	Hostname    string
	TLS         *TLS
	SQLDriver   string
	SQLSource   string
	LogPath     string
	HTTPOrigins []string
}

func Defaults() *Server {


@@ 90,6 91,8 @@ func Parse(r io.Reader) (*Server, error) {
			if err := d.parseParams(&srv.LogPath); err != nil {
				return nil, err
			}
		case "http-origin":
			srv.HTTPOrigins = append(srv.HTTPOrigins, d.Params...)
		default:
			return nil, fmt.Errorf("unknown directive %q", d.Name)
		}

M conn.go => conn.go +52 -2
@@ 1,12 1,14 @@
package soju

import (
	"context"
	"fmt"
	"net"
	"sync"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"
)

// ircConn is a generic IRC connection. It's similar to net.Conn but focuses on


@@ 15,11 17,11 @@ type ircConn interface {
	ReadMessage() (*irc.Message, error)
	WriteMessage(*irc.Message) error
	Close() error
	SetWriteDeadline(time.Time) error
	SetReadDeadline(time.Time) error
	SetWriteDeadline(time.Time) error
}

func netIRCConn(c net.Conn) ircConn {
func newNetIRCConn(c net.Conn) ircConn {
	type netConn net.Conn
	return struct {
		*irc.Conn


@@ 27,6 29,54 @@ func netIRCConn(c net.Conn) ircConn {
	}{irc.NewConn(c), c}
}

type websocketIRCConn struct {
	conn                        *websocket.Conn
	readDeadline, writeDeadline time.Time
}

func newWebsocketIRCConn(c *websocket.Conn) ircConn {
	return websocketIRCConn{conn: c}
}

func (wic websocketIRCConn) ReadMessage() (*irc.Message, error) {
	ctx := context.Background()
	if !wic.readDeadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, wic.readDeadline)
		defer cancel()
	}
	_, b, err := wic.conn.Read(ctx)
	if err != nil {
		return nil, err
	}
	return irc.ParseMessage(string(b))
}

func (wic websocketIRCConn) WriteMessage(msg *irc.Message) error {
	b := []byte(msg.String())
	ctx := context.Background()
	if !wic.writeDeadline.IsZero() {
		var cancel context.CancelFunc
		ctx, cancel = context.WithDeadline(ctx, wic.writeDeadline)
		defer cancel()
	}
	return wic.conn.Write(ctx, websocket.MessageText, b)
}

func (wic websocketIRCConn) Close() error {
	return wic.conn.Close(websocket.StatusNormalClosure, "")
}

func (wic websocketIRCConn) SetReadDeadline(t time.Time) error {
	wic.readDeadline = t
	return nil
}

func (wic websocketIRCConn) SetWriteDeadline(t time.Time) error {
	wic.writeDeadline = t
	return nil
}

type conn struct {
	conn   ircConn
	srv    *Server

M doc/soju.1.scd => doc/soju.1.scd +8 -0
@@ 72,6 72,10 @@ The config file has one directive per line.
	  omitted: 6697)
	- _irc+insecure://[host][:port]_ listens with plain-text over TCP (default
	  port if omitted: 6667)
	- _wss://[host][:port]_ listens for WebSocket connections over TLS (default
	  port: 443)
	- _ws+insecure://[host][:port]_ listens for plain-text WebSocket
	  connections (default port: 80)

	If the scheme is omitted, "ircs" is assumed. If multiple *listen*
	directives are specified, soju will listen on each of them.


@@ 91,6 95,10 @@ The config file has one directive per line.
	Path to the bouncer logs root directory, or empty to disable logging. By
	default, logging is disabled.

*http-origin* <patterns...>
	List of allowed HTTP origins for WebSocket listeners. The parameters are
	interpreted as shell patterns, see *glob*(7).

# IRC SERVICE

soju exposes an IRC service called *BouncerServ* to manage the bouncer.

M downstream.go => downstream.go +4 -4
@@ 99,15 99,15 @@ type downstreamConn struct {
	saslServer sasl.Server
}

func newDownstreamConn(srv *Server, netConn net.Conn, id uint64) *downstreamConn {
	logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", netConn.RemoteAddr())}
func newDownstreamConn(srv *Server, ic ircConn, remoteAddr string, id uint64) *downstreamConn {
	logger := &prefixLogger{srv.Logger, fmt.Sprintf("downstream %q: ", remoteAddr)}
	dc := &downstreamConn{
		conn:          *newConn(srv, netIRCConn(netConn), logger),
		conn:          *newConn(srv, ic, logger),
		id:            id,
		supportedCaps: make(map[string]string),
		caps:          make(map[string]bool),
	}
	dc.hostname = netConn.RemoteAddr().String()
	dc.hostname = remoteAddr
	if host, _, err := net.SplitHostPort(dc.hostname); err == nil {
		dc.hostname = host
	}

M go.mod => go.mod +1 -0
@@ 9,4 9,5 @@ require (
	golang.org/x/crypto v0.0.0-20200317142112-1b76d66859c6
	golang.org/x/sys v0.0.0-20200317113312-5766fd39f98d // indirect
	gopkg.in/irc.v3 v3.1.2
	nhooyr.io/websocket v1.8.5
)

M go.sum => go.sum +19 -0
@@ 2,8 2,22 @@ github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b h1:uhWtEWBHgop1rqEk2klKaxPAkVDCXexai6hSuRQ7Nvs=
github.com/emersion/go-sasl v0.0.0-20191210011802-430746ea8b9b/go.mod h1:G/dpzLu16WtQpBfQ/z3LYiYJn3ZhKSGWn83fyoyQe/k=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0=
github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo=
github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8=
github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw=
github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls=
github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk=
github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4=
github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4=
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ=
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8=
github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs=
github.com/mattn/go-sqlite3 v2.0.3+incompatible h1:gXHsfypPkaMZrKbD5209QV9jbUTJKjyR5WD3HYQSd+U=
github.com/mattn/go-sqlite3 v2.0.3+incompatible/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=


@@ 21,6 35,9 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20200317113312-5766fd39f98d h1:62ap6LNOjDU6uGmKXHJbSfciMoV+FeI1sRXx/pLDL44=
golang.org/x/sys v0.0.0-20200317113312-5766fd39f98d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/irc.v3 v3.1.2 h1:TruRvpbZ9QrE+ZxKeWxDdA2mlMajBczQ7ApZi/S3+7k=


@@ 28,3 45,5 @@ gopkg.in/irc.v3 v3.1.2/go.mod h1:shO2gz8+PVeS+4E6GAny88Z0YVVQSxQghdrMVGQsR9s=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
nhooyr.io/websocket v1.8.5 h1:DCqbsbyRh43Ky0pWkdbWXF6z6MS2W8LqJ4ym3F+fw3I=
nhooyr.io/websocket v1.8.5/go.mod h1:szdAKb/TINbpD/bAZy4Ydj5xgVo2BOLNPIi/mcAOGrU=

M server.go => server.go +34 -16
@@ 4,10 4,13 @@ import (
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	"gopkg.in/irc.v3"
	"nhooyr.io/websocket"
)

// TODO: make configurable


@@ 44,6 47,7 @@ type Server struct {
	HistoryLimit int
	LogPath      string
	Debug        bool
	HTTPOrigins  []string

	db *DB



@@ 91,27 95,41 @@ func (s *Server) getUser(name string) *user {
	return u
}

var lastDownstreamID uint64 = 0

func (s *Server) handle(ic ircConn, remoteAddr string) {
	id := atomic.AddUint64(&lastDownstreamID, 1)
	dc := newDownstreamConn(s, ic, remoteAddr, id)
	if err := dc.runUntilRegistered(); err != nil {
		dc.logger.Print(err)
	} else {
		dc.user.events <- eventDownstreamConnected{dc}
		if err := dc.readMessages(dc.user.events); err != nil {
			dc.logger.Print(err)
		}
		dc.user.events <- eventDownstreamDisconnected{dc}
	}
	dc.Close()
}

func (s *Server) Serve(ln net.Listener) error {
	var nextDownstreamID uint64 = 1
	for {
		netConn, err := ln.Accept()
		conn, err := ln.Accept()
		if err != nil {
			return fmt.Errorf("failed to accept connection: %v", err)
		}

		dc := newDownstreamConn(s, netConn, nextDownstreamID)
		nextDownstreamID++
		go func() {
			if err := dc.runUntilRegistered(); err != nil {
				dc.logger.Print(err)
			} else {
				dc.user.events <- eventDownstreamConnected{dc}
				if err := dc.readMessages(dc.user.events); err != nil {
					dc.logger.Print(err)
				}
				dc.user.events <- eventDownstreamDisconnected{dc}
			}
			dc.Close()
		}()
		go s.handle(newNetIRCConn(conn), conn.RemoteAddr().String())
	}
}

func (s *Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
	conn, err := websocket.Accept(w, req, &websocket.AcceptOptions{
		OriginPatterns: s.HTTPOrigins,
	})
	if err != nil {
		s.Logger.Printf("failed to serve HTTP connection: %v", err)
		return
	}
	s.handle(newWebsocketIRCConn(conn), req.RemoteAddr)
}

M upstream.go => upstream.go +1 -1
@@ 143,7 143,7 @@ func connectToUpstream(network *network) (*upstreamConn, error) {
	}

	uc := &upstreamConn{
		conn:                     *newConn(network.user.srv, netIRCConn(netConn), logger),
		conn:                     *newConn(network.user.srv, newNetIRCConn(netConn), logger),
		network:                  network,
		user:                     network.user,
		channels:                 make(map[string]*upstreamChannel),