~emersion/soju

1e4ff49472467e1e30c897608aeddb6921dc81c7 — Simon Ser 2 months ago 5b4469f
Save delivery receipts in DB

This avoids loosing history on restart for clients that don't
support chathistory.

Closes: https://todo.sr.ht/~emersion/soju/80
3 files changed, 161 insertions(+), 11 deletions(-)

M db.go
M upstream.go
M user.go
M db.go => db.go +90 -0
@@ 120,6 120,13 @@ type Channel struct {
	DetachOn      MessageFilter
}

type DeliveryReceipt struct {
	ID            int64
	Target        string // channel or nick
	Client        string
	InternalMsgID string
}

const schema = `
CREATE TABLE User (
	id INTEGER PRIMARY KEY,


@@ 161,6 168,16 @@ CREATE TABLE Channel (
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, name)
);

CREATE TABLE DeliveryReceipt (
	id INTEGER PRIMARY KEY,
	network INTEGER NOT NULL,
	target VARCHAR(255) NOT NULL,
	client VARCHAR(255),
	internal_msgid VARCHAR(255) NOT NULL,
	FOREIGN KEY(network) REFERENCES Network(id),
	UNIQUE(network, target, client)
);
`

var migrations = []string{


@@ 217,6 234,17 @@ var migrations = []string{
		ALTER TABLE Channel ADD COLUMN detach_after INTEGER NOT NULL DEFAULT 0;
		ALTER TABLE Channel ADD COLUMN detach_on INTEGER NOT NULL DEFAULT 0;
	`,
	`
		CREATE TABLE DeliveryReceipt (
			id INTEGER PRIMARY KEY,
			network INTEGER NOT NULL,
			target VARCHAR(255) NOT NULL,
			client VARCHAR(255),
			internal_msgid VARCHAR(255) NOT NULL,
			FOREIGN KEY(network) REFERENCES Network(id),
			UNIQUE(network, target, client)
		);
	`,
}

type DB struct {


@@ 578,3 606,65 @@ func (db *DB) DeleteChannel(id int64) error {
	_, err := db.db.Exec("DELETE FROM Channel WHERE id = ?", id)
	return err
}

func (db *DB) ListDeliveryReceipts(networkID int64) ([]DeliveryReceipt, error) {
	db.lock.RLock()
	defer db.lock.RUnlock()

	rows, err := db.db.Query(`SELECT id, target, client, internal_msgid
		FROM DeliveryReceipt
		WHERE network = ?`, networkID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var receipts []DeliveryReceipt
	for rows.Next() {
		var rcpt DeliveryReceipt
		var client sql.NullString
		if err := rows.Scan(&rcpt.ID, &rcpt.Target, &client, &rcpt.InternalMsgID); err != nil {
			return nil, err
		}
		rcpt.Client = client.String
		receipts = append(receipts, rcpt)
	}
	if err := rows.Err(); err != nil {
		return nil, err
	}

	return receipts, nil
}

func (db *DB) StoreClientDeliveryReceipts(networkID int64, client string, receipts []DeliveryReceipt) error {
	db.lock.Lock()
	defer db.lock.Unlock()

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.Rollback()

	_, err = tx.Exec("DELETE FROM DeliveryReceipt WHERE network = ? AND client = ?",
		networkID, toNullString(client))
	if err != nil {
		return err
	}

	for i := range receipts {
		rcpt := &receipts[i]

		res, err := tx.Exec("INSERT INTO DeliveryReceipt(network, target, client, internal_msgid) VALUES (?, ?, ?, ?)",
			networkID, rcpt.Target, toNullString(client), rcpt.InternalMsgID)
		if err != nil {
			return err
		}
		rcpt.ID, err = res.LastInsertId()
		if err != nil {
			return err
		}
	}

	return tx.Commit()
}

M upstream.go => upstream.go +2 -2
@@ 1752,9 1752,9 @@ func (uc *upstreamConn) appendLog(entity string, msg *irc.Message) (msgID string
			return ""
		}

		for clientName, _ := range uc.user.clientNames {
		uc.network.delivered.ForEachClient(func(clientName string) {
			uc.network.delivered.StoreID(entity, clientName, lastID)
		}
		})
	}

	msgID, err := uc.user.msgStore.Append(uc.network, entityCM, msg)

M user.go => user.go +69 -9
@@ 92,6 92,20 @@ func (ds deliveredStore) ForEachTarget(f func(target string)) {
	}
}

func (ds deliveredStore) ForEachClient(f func(clientName string)) {
	clients := make(map[string]struct{})
	for _, entry := range ds.m.innerMap {
		delivered := entry.value.(deliveredClientMap)
		for clientName := range delivered {
			clients[clientName] = struct{}{}
		}
	}

	for clientName := range clients {
		f(clientName)
	}
}

type network struct {
	Network
	user    *user


@@ 298,6 312,28 @@ func (net *network) updateCasemapping(newCasemap casemapping) {
	}
}

func (net *network) storeClientDeliveryReceipts(clientName string) {
	if !net.user.hasPersistentMsgStore() {
		return
	}

	var receipts []DeliveryReceipt
	net.delivered.ForEachTarget(func(target string) {
		msgID := net.delivered.LoadID(target, clientName)
		if msgID == "" {
			return
		}
		receipts = append(receipts, DeliveryReceipt{
			Target:        target,
			InternalMsgID: msgID,
		})
	})

	if err := net.user.srv.db.StoreClientDeliveryReceipts(net.ID, clientName, receipts); err != nil {
		net.user.srv.Logger.Printf("failed to store delivery receipts for user %q, client %q, network %q: %v", net.user.Username, clientName, net.GetName(), err)
	}
}

type user struct {
	User
	srv *Server


@@ 308,7 344,6 @@ type user struct {
	networks        []*network
	downstreamConns []*downstreamConn
	msgStore        messageStore
	clientNames     map[string]struct{}

	// LIST commands in progress
	pendingLISTs []pendingLIST


@@ 329,12 364,11 @@ func newUser(srv *Server, record *User) *user {
	}

	return &user{
		User:        *record,
		srv:         srv,
		events:      make(chan event, 64),
		done:        make(chan struct{}),
		msgStore:    msgStore,
		clientNames: make(map[string]struct{}),
		User:     *record,
		srv:      srv,
		events:   make(chan event, 64),
		done:     make(chan struct{}),
		msgStore: msgStore,
	}
}



@@ 407,6 441,18 @@ func (u *user) run() {
		network := newNetwork(u, &record, channels)
		u.networks = append(u.networks, network)

		if u.hasPersistentMsgStore() {
			receipts, err := u.srv.db.ListDeliveryReceipts(record.ID)
			if err != nil {
				u.srv.Logger.Printf("failed to load delivery receipts for user %q, network %q: %v", u.Username, network.GetName(), err)
				return
			}

			for _, rcpt := range receipts {
				network.delivered.StoreID(rcpt.Target, rcpt.Client, rcpt.InternalMsgID)
			}
		}

		go network.run()
	}



@@ 489,8 535,6 @@ func (u *user) run() {
			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})

			u.clientNames[dc.clientName] = struct{}{}
		case eventDownstreamDisconnected:
			dc := e.dc



@@ 501,6 545,10 @@ func (u *user) run() {
				}
			}

			dc.forEachNetwork(func(net *network) {
				net.storeClientDeliveryReceipts(dc.clientName)
			})

			u.forEachUpstream(func(uc *upstreamConn) {
				uc.updateAway()
			})


@@ 524,6 572,10 @@ func (u *user) run() {
			})
			for _, n := range u.networks {
				n.stop()

				n.delivered.ForEachClient(func(clientName string) {
					n.storeClientDeliveryReceipts(clientName)
				})
			}
			return
		default:


@@ 665,3 717,11 @@ func (u *user) stop() {
	u.events <- eventStop{}
	<-u.done
}

func (u *user) hasPersistentMsgStore() bool {
	if u.msgStore == nil {
		return false
	}
	_, isMem := u.msgStore.(*memoryMessageStore)
	return !isMem
}