~emersion/soju

11b4d105a55cf983b9b865ea2d43c28f75dafc35 — Simon Ser 5 months ago 4af7a1b webpush
Add webpush extension

References: https://github.com/ircv3/ircv3-specifications/pull/471
Co-authored-by: delthas <delthas@dille.cc>
9 files changed, 632 insertions(+), 0 deletions(-)

M database/database.go
M database/postgres.go
M database/sqlite.go
M downstream.go
M go.mod
M go.sum
M server.go
M upstream.go
M user.go
M database/database.go => database/database.go +24 -0
@@ 31,6 31,13 @@ type Database interface {

	GetReadReceipt(ctx context.Context, networkID int64, name string) (*ReadReceipt, error)
	StoreReadReceipt(ctx context.Context, networkID int64, receipt *ReadReceipt) error

	ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
	StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error

	ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error)
	StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error
	DeleteWebPushSubscription(ctx context.Context, id int64) error
}

type MetricsCollectorDatabase interface {


@@ 175,3 182,20 @@ type ReadReceipt struct {
	Target    string // channel or nick
	Timestamp time.Time
}

type WebPushConfig struct {
	ID        int64
	VAPIDKeys struct {
		Public, Private string
	}
}

type WebPushSubscription struct {
	ID       int64
	Endpoint string
	Keys     struct {
		Auth   string
		P256DH string
		VAPID  string
	}
}

M database/postgres.go => database/postgres.go +139 -0
@@ 85,6 85,26 @@ CREATE TABLE "ReadReceipt" (
	timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
	UNIQUE(network, target)
);

CREATE TABLE "WebPushConfig" (
	id SERIAL PRIMARY KEY,
	created_at TIMESTAMP WITH TIME ZONE NOT NULL,
	vapid_key_public TEXT NOT NULL,
	vapid_key_private TEXT NOT NULL,
	UNIQUE(vapid_key_public)
);

CREATE TABLE "WebPushSubscription" (
	id SERIAL PRIMARY KEY,
	created_at TIMESTAMP WITH TIME ZONE NOT NULL,
	updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
	network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
	endpoint TEXT NOT NULL,
	key_vapid TEXT,
	key_auth TEXT,
	key_p256dh TEXT,
	UNIQUE(network, endpoint)
);
`

var postgresMigrations = []string{


@@ 106,6 126,27 @@ var postgresMigrations = []string{
			UNIQUE(network, target)
		);
	`,
	`
		CREATE TABLE "WebPushConfig" (
			id SERIAL PRIMARY KEY,
			created_at TIMESTAMP WITH TIME ZONE NOT NULL,
			vapid_key_public TEXT NOT NULL,
			vapid_key_private TEXT NOT NULL,
			UNIQUE(vapid_key_public)
		);

		CREATE TABLE "WebPushSubscription" (
			id SERIAL PRIMARY KEY,
			created_at TIMESTAMP WITH TIME ZONE NOT NULL,
			updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
			network INTEGER NOT NULL REFERENCES "Network"(id) ON DELETE CASCADE,
			endpoint TEXT NOT NULL,
			key_vapid TEXT,
			key_auth TEXT,
			key_p256dh TEXT,
			UNIQUE(network, endpoint)
		);
	`,
}

type PostgresDB struct {


@@ 623,6 664,104 @@ func (db *PostgresDB) listTopNetworkAddrs(ctx context.Context) (map[string]int, 
	return addrs, rows.Err()
}

func (db *PostgresDB) ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	rows, err := db.db.QueryContext(ctx, `
		SELECT id, vapid_key_public, vapid_key_private
		FROM "WebPushConfig"`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var configs []WebPushConfig
	for rows.Next() {
		var config WebPushConfig
		if err := rows.Scan(&config.ID, &config.VAPIDKeys.Public, &config.VAPIDKeys.Private); err != nil {
			return nil, err
		}
		configs = append(configs, config)
	}

	return configs, rows.Err()
}

func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	if config.ID != 0 {
		return fmt.Errorf("cannot update a WebPushConfig")
	}

	err := db.db.QueryRowContext(ctx, `
		INSERT INTO "WebPushConfig" (created_at, vapid_key_public, vapid_key_private)
		VALUES (NOW(), $1, $2)
		RETURNING id`,
		config.VAPIDKeys.Public, config.VAPIDKeys.Private).Scan(&config.ID)
	return err
}

func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	rows, err := db.db.QueryContext(ctx, `
		SELECT id, endpoint, key_auth, key_p256dh, key_vapid
		FROM "WebPushSubscription"
		WHERE network = $1`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var subs []WebPushSubscription
	for rows.Next() {
		var sub WebPushSubscription
		if err := rows.Scan(&sub.ID, &sub.Endpoint, &sub.Keys.Auth, &sub.Keys.P256DH, &sub.Keys.VAPID); err != nil {
			return nil, err
		}
		subs = append(subs, sub)
	}

	return subs, rows.Err()
}

func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	var err error
	if sub.ID != 0 {
		_, err = db.db.ExecContext(ctx, `
			UPDATE "WebPushSubscription"
			SET updated_at = NOW(), key_auth = $1, key_p256dh = $2,
				key_vapid = $3
			WHERE id = $4`,
			sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID)
	} else {
		err = db.db.QueryRowContext(ctx, `
			INSERT INTO "WebPushSubscription" (created_at, updated_at, network,
				endpoint, key_auth, key_p256dh, key_vapid)
			VALUES (NOW(), NOW(), $1, $2, $3, $4, $5)
			RETURNING id`,
			networkID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
			sub.Keys.VAPID).Scan(&sub.ID)
	}

	return err
}

func (db *PostgresDB) DeleteWebPushSubscription(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
	defer cancel()

	_, err := db.db.ExecContext(ctx, `DELETE FROM "WebPushSubscription" WHERE id = $1`, id)
	return err
}

var postgresNetworksTotalDesc = prometheus.NewDesc("soju_networks_total", "Number of networks", []string{"hostname"}, nil)

type postgresMetricsCollector struct {

M database/sqlite.go => database/sqlite.go +166 -0
@@ 84,6 84,27 @@ CREATE TABLE ReadReceipt (
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target)
);

CREATE TABLE WebPushConfig (
	id INTEGER PRIMARY KEY,
	created_at TEXT NOT NULL,
	vapid_key_public TEXT NOT NULL,
	vapid_key_private TEXT NOT NULL,
	UNIQUE(vapid_key_public)
);

CREATE TABLE WebPushSubscription (
	id INTEGER PRIMARY KEY,
	created_at TEXT NOT NULL,
	updated_at TEXT NOT NULL,
	network INTEGER NOT NULL,
	endpoint TEXT NOT NULL,
	key_vapid TEXT,
	key_auth TEXT,
	key_p256dh TEXT,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, endpoint)
);
`

var sqliteMigrations = []string{


@@ 194,6 215,28 @@ var sqliteMigrations = []string{
			UNIQUE(network, target)
		);
	`,
	`
		CREATE TABLE WebPushConfig (
			id INTEGER PRIMARY KEY,
			created_at TEXT NOT NULL,
			vapid_key_public TEXT NOT NULL,
			vapid_key_private TEXT NOT NULL,
			UNIQUE(vapid_key_public)
		);

		CREATE TABLE WebPushSubscription (
			id INTEGER PRIMARY KEY,
			created_at TEXT NOT NULL,
			updated_at TEXT NOT NULL,
			network INTEGER NOT NULL,
			endpoint TEXT NOT NULL,
			key_vapid TEXT,
			key_auth TEXT,
			key_p256dh TEXT,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, endpoint)
		);
	`,
}

type SqliteDB struct {


@@ 555,6 598,11 @@ func (db *SqliteDB) DeleteNetwork(ctx context.Context, id int64) error {
	}
	defer tx.Rollback()

	_, err = tx.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE network = ?", id)
	if err != nil {
		return err
	}

	_, err = tx.ExecContext(ctx, "DELETE FROM DeliveryReceipt WHERE network = ?", id)
	if err != nil {
		return err


@@ 784,3 832,121 @@ func (db *SqliteDB) StoreReadReceipt(ctx context.Context, networkID int64, recei

	return err
}

func (db *SqliteDB) ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error) {
	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
	defer cancel()

	rows, err := db.db.QueryContext(ctx, `
		SELECT id, vapid_key_public, vapid_key_private
		FROM WebPushConfig`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var configs []WebPushConfig
	for rows.Next() {
		var config WebPushConfig
		if err := rows.Scan(&config.ID, &config.VAPIDKeys.Public, &config.VAPIDKeys.Private); err != nil {
			return nil, err
		}
		configs = append(configs, config)
	}

	return configs, rows.Err()
}

func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error {
	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
	defer cancel()

	if config.ID != 0 {
		return fmt.Errorf("cannot update a WebPushConfig")
	}

	res, err := db.db.ExecContext(ctx, `
		INSERT INTO WebPushConfig(created_at, vapid_key_public, vapid_key_private)
		VALUES (:now, :vapid_key_public, :vapid_key_private)`,
		sql.Named("vapid_key_public", config.VAPIDKeys.Public),
		sql.Named("vapid_key_private", config.VAPIDKeys.Private),
		sql.Named("now", formatSqliteTime(time.Now())))
	if err != nil {
		return err
	}
	config.ID, err = res.LastInsertId()
	return err
}

func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
	defer cancel()

	rows, err := db.db.QueryContext(ctx, `
		SELECT id, endpoint, key_auth, key_p256dh, key_vapid
		FROM WebPushSubscription
		WHERE network = ?`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var subs []WebPushSubscription
	for rows.Next() {
		var sub WebPushSubscription
		if err := rows.Scan(&sub.ID, &sub.Endpoint, &sub.Keys.Auth, &sub.Keys.P256DH, &sub.Keys.VAPID); err != nil {
			return nil, err
		}
		subs = append(subs, sub)
	}

	return subs, rows.Err()
}

func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
	defer cancel()

	args := []interface{}{
		sql.Named("id", sub.ID),
		sql.Named("network", networkID),
		sql.Named("now", formatSqliteTime(time.Now())),
		sql.Named("endpoint", sub.Endpoint),
		sql.Named("key_auth", sub.Keys.Auth),
		sql.Named("key_p256dh", sub.Keys.P256DH),
		sql.Named("key_vapid", sub.Keys.VAPID),
	}

	var err error
	if sub.ID != 0 {
		_, err = db.db.ExecContext(ctx, `
			UPDATE WebPushSubscription
			SET updated_at = :now, key_auth = :key_auth, key_p256dh = :key_p256dh,
				key_vapid = :key_vapid
			WHERE id = :id`,
			args...)
	} else {
		var res sql.Result
		res, err = db.db.ExecContext(ctx, `
			INSERT INTO
			WebPushSubscription(created_at, updated_at, network, endpoint,
				key_auth, key_p256dh, key_vapid)
			VALUES (:now, :now, :network, :endpoint, :key_auth, :key_p256dh,
				:key_vapid)`,
			args...)
		if err != nil {
			return err
		}
		sub.ID, err = res.LastInsertId()
	}

	return err
}

func (db *SqliteDB) DeleteWebPushSubscription(ctx context.Context, id int64) error {
	ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
	defer cancel()

	_, err := db.db.ExecContext(ctx, "DELETE FROM WebPushSubscription WHERE id = ?", id)
	return err
}

M downstream.go => downstream.go +187 -0
@@ 9,10 9,12 @@ import (
	"fmt"
	"io"
	"net"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/SherClockHolmes/webpush-go"
	"github.com/emersion/go-sasl"
	"golang.org/x/crypto/bcrypt"
	"gopkg.in/irc.v3"


@@ 835,6 837,7 @@ func (dc *downstreamConn) handleCapCommand(cmd string, args []string) error {
			for k, v := range needAllDownstreamCaps {
				dc.caps.Available[k] = v
			}
			dc.caps.Available["soju.im/webpush"] = ""
		}

		caps := make([]string, 0, len(dc.caps.Available))


@@ 1143,6 1146,12 @@ func (dc *downstreamConn) updateSupportedCaps() {
	} else {
		dc.unsetSupportedCap("draft/event-playback")
	}

	if dc.network != nil {
		dc.setSupportedCap("soju.im/webpush", "")
	} else {
		dc.unsetSupportedCap("soju.im/webpush")
	}
}

func (dc *downstreamConn) updateNick() {


@@ 1501,6 1510,9 @@ func (dc *downstreamConn) welcome(ctx context.Context) error {
	if dc.network == nil && !dc.isMultiUpstream {
		isupport = append(isupport, "WHOX")
	}
	if dc.caps.IsEnabled("soju.im/webpush") {
		isupport = append(isupport, "VAPID="+dc.srv.webPush.VAPIDKeys.Public)
	}

	if uc := dc.upstream(); uc != nil {
		for k := range passthroughIsupport {


@@ 3200,6 3212,135 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
				Params:  []string{"BOUNCER", "UNKNOWN_COMMAND", subcommand, "Unknown subcommand"},
			}}
		}
	case "WEBPUSH":
		// We don't support Web push without a bound network
		if !dc.caps.IsEnabled("soju.im/webpush") {
			return newUnknownCommandError(msg.Command)
		}

		var subcommand string
		if err := parseMessageParams(msg, &subcommand); err != nil {
			return err
		}

		switch subcommand {
		case "REGISTER":
			var endpoint, keysStr string
			if err := parseMessageParams(msg, nil, &endpoint, &keysStr); err != nil {
				return err
			}

			if err := checkWebPushEndpoint(ctx, endpoint); err != nil {
				dc.logger.Printf("failed to check Web push endpoint %q: %v", endpoint, err)
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INVALID_PARAMS", subcommand, "Invalid endpoint"},
				}}
			}

			rawKeys := irc.ParseTags(keysStr)
			authKey, hasAuthKey := rawKeys["auth"]
			p256dhKey, hasP256dh := rawKeys["p256dh"]
			if !hasAuthKey || !hasP256dh {
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INVALID_PARAMS", subcommand, "Missing auth or p256dh key"},
				}}
			}

			newSub := database.WebPushSubscription{
				Endpoint: endpoint,
			}
			newSub.Keys.VAPID = dc.srv.webPush.VAPIDKeys.Public
			newSub.Keys.Auth = string(authKey)
			newSub.Keys.P256DH = string(p256dhKey)

			oldSub, err := dc.findWebPushSubscription(ctx, endpoint)
			if err != nil {
				dc.logger.Printf("failed to fetch Web push subscription: %v", err)
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"},
				}}
			}

			if oldSub != nil {
				if oldSub.Keys.VAPID == newSub.Keys.VAPID && oldSub.Keys.Auth == newSub.Keys.Auth && oldSub.Keys.P256DH == newSub.Keys.P256DH {
					// Nothing has changed, this is a no-op
					return nil
				}

				// Update the old subscription instead of creating a new one
				newSub.ID = oldSub.ID
			}

			// TODO: limit max number of subscriptions, prune old ones

			if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.network.ID, &newSub); err != nil {
				dc.logger.Printf("failed to store Web push subscription: %v", err)
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"},
				}}
			}

			err = dc.srv.sendWebPush(ctx, &webpush.Subscription{
				Endpoint: newSub.Endpoint,
				Keys: webpush.Keys{
					Auth:   newSub.Keys.Auth,
					P256dh: newSub.Keys.P256DH,
				},
			}, newSub.Keys.VAPID, &irc.Message{
				Command: "NOTE",
				Params:  []string{"WEBPUSH", "REGISTERED", "Push notifications enabled"},
			})
			if err != nil {
				dc.logger.Printf("failed to send Web push notification to endpoint %q: %v", newSub.Endpoint, err)
			}

			dc.SendMessage(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: "WEBPUSH",
				Params:  []string{"REGISTER", endpoint},
			})
		case "UNREGISTER":
			var endpoint string
			if err := parseMessageParams(msg, nil, &endpoint); err != nil {
				return err
			}

			oldSub, err := dc.findWebPushSubscription(ctx, endpoint)
			if err != nil {
				dc.logger.Printf("failed to fetch Web push subscription: %v", err)
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"},
				}}
			}

			if oldSub == nil {
				return nil
			}

			if err := dc.srv.db.DeleteWebPushSubscription(ctx, oldSub.ID); err != nil {
				dc.logger.Printf("failed to delete Web push subscription: %v", err)
				return ircError{&irc.Message{
					Command: "FAIL",
					Params:  []string{"WEBPUSH", "INTERNAL_ERROR", subcommand, "Internal error"},
				}}
			}

			dc.SendMessage(&irc.Message{
				Prefix:  dc.srv.prefix(),
				Command: "WEBPUSH",
				Params:  []string{"UNREGISTER", endpoint},
			})
		default:
			return ircError{&irc.Message{
				Command: "FAIL",
				Params:  []string{"WEBPUSH", "INVALID_PARAMS", subcommand, "Unknown command"},
			}}
		}
	default:
		dc.logger.Printf("unhandled message: %v", msg)



@@ 3228,6 3369,20 @@ func (dc *downstreamConn) handleNickServPRIVMSG(ctx context.Context, uc *upstrea
	}
}

func (dc *downstreamConn) findWebPushSubscription(ctx context.Context, endpoint string) (*database.WebPushSubscription, error) {
	subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, dc.network.ID)
	if err != nil {
		return nil, err
	}

	for i, sub := range subs {
		if sub.Endpoint == endpoint {
			return &subs[i], nil
		}
	}
	return nil, nil
}

func parseNickServCredentials(text, nick string) (username, password string, ok bool) {
	fields := strings.Fields(text)
	if len(fields) < 2 {


@@ 3257,3 3412,35 @@ func parseNickServCredentials(text, nick string) (username, password string, ok 
	}
	return username, password, true
}

func checkWebPushEndpoint(ctx context.Context, endpoint string) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodOptions, endpoint, nil)
	if err != nil {
		return fmt.Errorf("failed to create HTTP request: %v", err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("HTTP request failed: %v", err)
	}
	resp.Body.Close()

	if resp.StatusCode/100 != 2 {
		return fmt.Errorf("HTTP request failed: %v", resp.Status)
	}

	allow := strings.Split(resp.Header.Get("Allow"), ",")
	found := false
	for _, method := range allow {
		if strings.EqualFold(strings.TrimSpace(method), http.MethodPost) {
			found = true
			break
		}
	}

	if !found {
		return fmt.Errorf("POST missing from Allow header in OPTIONS response")
	}

	return nil
}

M go.mod => go.mod +1 -0
@@ 5,6 5,7 @@ go 1.15
require (
	git.sr.ht/~emersion/go-scfg v0.0.0-20211215104734-c2c7a15d6c99
	git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9
	github.com/SherClockHolmes/webpush-go v1.2.0
	github.com/emersion/go-sasl v0.0.0-20211008083017-0b9dcfb154ac
	github.com/klauspost/compress v1.14.4 // indirect
	github.com/lib/pq v1.10.4

M go.sum => go.sum +5 -0
@@ 38,6 38,8 @@ git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9 h1:Ahny8Ud1LjVMMA
git.sr.ht/~sircmpwn/go-bare v0.0.0-20210406120253-ab86bc2846d9/go.mod h1:BVJwbDfVjCjoFiKrhkei6NdGcZYpkDkdyCdg1ukytRA=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/SherClockHolmes/webpush-go v1.2.0 h1:sGv0/ZWCvb1HUH+izLqrb2i68HuqD/0Y+AmGQfyqKJA=
github.com/SherClockHolmes/webpush-go v1.2.0/go.mod h1:w6X47YApe/B9wUz2Wh8xukxlyupaxSSEbu6yKJcHN2w=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=


@@ 94,6 96,8 @@ github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6Wezm
github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo=
github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=
github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc=


@@ 249,6 253,7 @@ go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190131182504-b8fe1690c613/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=

M server.go => server.go +73 -0
@@ 14,6 14,7 @@ import (
	"sync/atomic"
	"time"

	"github.com/SherClockHolmes/webpush-go"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	"gopkg.in/irc.v3"


@@ 38,6 39,8 @@ var downstreamRegisterTimeout = 30 * time.Second
var chatHistoryLimit = 1000
var backlogLimit = 4000

var errWebPushSubscriptionExpired = fmt.Errorf("Web Push subscription expired")

type Logger interface {
	Printf(format string, v ...interface{})
	Debugf(format string, v ...interface{})


@@ 165,6 168,8 @@ type Server struct {

		upstreamConnectErrorsTotal prometheus.Counter
	}

	webPush *database.WebPushConfig
}

func NewServer(db database.Database) *Server {


@@ 197,6 202,10 @@ func (s *Server) SetConfig(cfg *Config) {
func (s *Server) Start() error {
	s.registerMetrics()

	if err := s.loadWebPushConfig(context.TODO()); err != nil {
		return err
	}

	users, err := s.db.ListUsers(context.TODO())
	if err != nil {
		return err


@@ 260,6 269,70 @@ func (s *Server) registerMetrics() {
	})
}

func (s *Server) loadWebPushConfig(ctx context.Context) error {
	configs, err := s.db.ListWebPushConfigs(ctx)
	if err != nil {
		return fmt.Errorf("failed to list Web push configs: %v", err)
	}

	if len(configs) > 1 {
		return fmt.Errorf("expected zero or one Web push config, got %v", len(configs))
	} else if len(configs) == 1 {
		s.webPush = &configs[0]
		return nil
	}

	s.Logger.Printf("Generating Web push VAPID key pair")
	priv, pub, err := webpush.GenerateVAPIDKeys()
	if err != nil {
		return fmt.Errorf("failed to generate Web push VAPID key pair: %v", err)
	}

	config := new(database.WebPushConfig)
	config.VAPIDKeys.Public = pub
	config.VAPIDKeys.Private = priv
	if err := s.db.StoreWebPushConfig(ctx, config); err != nil {
		return fmt.Errorf("failed to store Web push config: %v", err)
	}

	s.webPush = config
	return nil
}

func (s *Server) sendWebPush(ctx context.Context, sub *webpush.Subscription, vapidPubKey string, msg *irc.Message) error {
	ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
	defer cancel()

	options := webpush.Options{
		VAPIDPublicKey:  s.webPush.VAPIDKeys.Public,
		VAPIDPrivateKey: s.webPush.VAPIDKeys.Private,
		Subscriber:      "https://soju.im",
		TTL:             7 * 24 * 60 * 60, // seconds
		Urgency:         webpush.UrgencyHigh,
		RecordSize:      2048,
	}

	if vapidPubKey != options.VAPIDPublicKey {
		return fmt.Errorf("unknown VAPID public key %q", vapidPubKey)
	}

	payload := []byte(msg.String())
	resp, err := webpush.SendNotificationWithContext(ctx, payload, sub, &options)
	if err != nil {
		return err
	}
	resp.Body.Close()

	// 404 means the subscription has expired as per RFC 8030 section 7.3
	if resp.StatusCode == http.StatusNotFound {
		return errWebPushSubscriptionExpired
	} else if resp.StatusCode/100 != 2 {
		return fmt.Errorf("HTTP error: %v", resp.Status)
	}

	return nil
}

func (s *Server) Shutdown() {
	s.lock.Lock()
	for ln := range s.listeners {

M upstream.go => upstream.go +10 -0
@@ 516,6 516,12 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
				if ch.DetachOn == database.FilterMessage || ch.DetachOn == database.FilterDefault || (ch.DetachOn == database.FilterHighlight && highlight) {
					uc.updateChannelAutoDetach(target)
				}
				if highlight {
					uc.network.broadcastWebPush(ctx, msg)
				}
			}
			if ch == nil && uc.isOurNick(entity) {
				uc.network.broadcastWebPush(ctx, msg)
			}

			uc.produce(target, msg, downstreamID)


@@ 1529,6 1535,10 @@ func (uc *upstreamConn) handleMessage(ctx context.Context, msg *irc.Message) err
				Params:  []string{dc.marshalEntity(uc.network, nick), dc.marshalEntity(uc.network, channel)},
			})
		})

		if weAreInvited {
			uc.network.broadcastWebPush(ctx, msg)
		}
	case irc.RPL_INVITING:
		var nick, channel string
		if err := parseMessageParams(msg, nil, &nick, &channel); err != nil {

M user.go => user.go +27 -0
@@ 13,6 13,7 @@ import (
	"strings"
	"time"

	"github.com/SherClockHolmes/webpush-go"
	"gopkg.in/irc.v3"

	"git.sr.ht/~emersion/soju/database"


@@ 445,6 446,32 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
	}
}

func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) {
	subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.ID)
	if err != nil {
		net.logger.Printf("failed to list Web push subscriptions: %v", err)
		return
	}

	for _, sub := range subs {
		err := net.user.srv.sendWebPush(ctx, &webpush.Subscription{
			Endpoint: sub.Endpoint,
			Keys: webpush.Keys{
				Auth:   sub.Keys.Auth,
				P256dh: sub.Keys.P256DH,
			},
		}, sub.Keys.VAPID, msg)
		if err != nil {
			net.logger.Printf("failed to send Web push notification to endpoint %q: %v", sub.Endpoint, err)
		}
		if err == errWebPushSubscriptionExpired {
			if err := net.user.srv.db.DeleteWebPushSubscription(ctx, sub.ID); err != nil {
				net.logger.Printf("failed to delete expired Web Push subscription: %v", err)
			}
		}
	}
}

type user struct {
	database.User
	srv    *Server