M database/database.go => database/database.go +2 -2
@@ 36,8 36,8 @@ type Database interface {
ListWebPushConfigs(ctx context.Context) ([]WebPushConfig, error)
StoreWebPushConfig(ctx context.Context, config *WebPushConfig) error
- ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error)
- StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error
+ ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error)
+ StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error
DeleteWebPushSubscription(ctx context.Context, id int64) error
}
M database/postgres.go => database/postgres.go +13 -7
@@ 98,6 98,7 @@ CREATE TABLE "WebPushSubscription" (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
updated_at TIMESTAMP WITH TIME ZONE NOT NULL,
+ "user" INTEGER REFERENCES "User"(id) ON DELETE CASCADE,
network INTEGER REFERENCES "Network"(id) ON DELETE CASCADE,
endpoint TEXT NOT NULL,
key_vapid TEXT,
@@ 147,6 148,11 @@ var postgresMigrations = []string{
UNIQUE(network, endpoint)
);
`,
+ `
+ ALTER TABLE "WebPushSubscription"
+ ADD COLUMN "user" INTEGER
+ REFERENCES "User"(id) ON DELETE CASCADE
+ `,
}
type PostgresDB struct {
@@ 704,7 710,7 @@ func (db *PostgresDB) StoreWebPushConfig(ctx context.Context, config *WebPushCon
return err
}
-func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
+func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
@@ 716,7 722,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM "WebPushSubscription"
- WHERE network IS NOT DISTINCT FROM $1`, nullNetworkID)
+ WHERE "user" = $1 AND network IS NOT DISTINCT FROM $2`, userID, nullNetworkID)
if err != nil {
return nil, err
}
@@ 734,7 740,7 @@ func (db *PostgresDB) ListWebPushSubscriptions(ctx context.Context, networkID in
return subs, rows.Err()
}
-func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
+func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, postgresQueryTimeout)
defer cancel()
@@ 753,11 759,11 @@ func (db *PostgresDB) StoreWebPushSubscription(ctx context.Context, networkID in
sub.Keys.Auth, sub.Keys.P256DH, sub.Keys.VAPID, sub.ID)
} else {
err = db.db.QueryRowContext(ctx, `
- INSERT INTO "WebPushSubscription" (created_at, updated_at, network,
- endpoint, key_auth, key_p256dh, key_vapid)
- VALUES (NOW(), NOW(), $1, $2, $3, $4, $5)
+ INSERT INTO "WebPushSubscription" (created_at, updated_at, "user",
+ network, endpoint, key_auth, key_p256dh, key_vapid)
+ VALUES (NOW(), NOW(), $1, $2, $3, $4, $5, $6)
RETURNING id`,
- nullNetworkID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
+ nullNetworkID, userID, sub.Endpoint, sub.Keys.Auth, sub.Keys.P256DH,
sub.Keys.VAPID).Scan(&sub.ID)
}
M database/sqlite.go => database/sqlite.go +12 -6
@@ 97,11 97,13 @@ CREATE TABLE WebPushSubscription (
id INTEGER PRIMARY KEY,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
+ user INTEGER NOT NULL,
network INTEGER,
endpoint TEXT NOT NULL,
key_vapid TEXT,
key_auth TEXT,
key_p256dh TEXT,
+ FOREIGN KEY(user) REFERENCES User(id),
FOREIGN KEY(network) REFERENCES Network(id),
UNIQUE(network, endpoint)
);
@@ 237,6 239,9 @@ var sqliteMigrations = []string{
UNIQUE(network, endpoint)
);
`,
+ `
+ ALTER TABLE WebPushSubscription ADD COLUMN user INTEGER REFERENCES User(id);
+ `,
}
type SqliteDB struct {
@@ 878,7 883,7 @@ func (db *SqliteDB) StoreWebPushConfig(ctx context.Context, config *WebPushConfi
return err
}
-func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int64) ([]WebPushSubscription, error) {
+func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, userID, networkID int64) ([]WebPushSubscription, error) {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
@@ 890,7 895,7 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
rows, err := db.db.QueryContext(ctx, `
SELECT id, endpoint, key_auth, key_p256dh, key_vapid
FROM WebPushSubscription
- WHERE network IS ?`, nullNetworkID)
+ WHERE user = ? AND network IS ?`, userID, nullNetworkID)
if err != nil {
return nil, err
}
@@ 908,12 913,13 @@ func (db *SqliteDB) ListWebPushSubscriptions(ctx context.Context, networkID int6
return subs, rows.Err()
}
-func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int64, sub *WebPushSubscription) error {
+func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, userID, networkID int64, sub *WebPushSubscription) error {
ctx, cancel := context.WithTimeout(ctx, sqliteQueryTimeout)
defer cancel()
args := []interface{}{
sql.Named("id", sub.ID),
+ sql.Named("user", userID),
sql.Named("network", sql.NullInt64{
Int64: networkID,
Valid: networkID != 0,
@@ 937,10 943,10 @@ func (db *SqliteDB) StoreWebPushSubscription(ctx context.Context, networkID int6
var res sql.Result
res, err = db.db.ExecContext(ctx, `
INSERT INTO
- WebPushSubscription(created_at, updated_at, network, endpoint,
+ WebPushSubscription(created_at, updated_at, user, network, endpoint,
key_auth, key_p256dh, key_vapid)
- VALUES (:now, :now, :network, :endpoint, :key_auth, :key_p256dh,
- :key_vapid)`,
+ VALUES (:now, :now, :user, :network, :endpoint, :key_auth,
+ :key_p256dh, :key_vapid)`,
args...)
if err != nil {
return err
M downstream.go => downstream.go +2 -2
@@ 3278,7 3278,7 @@ func (dc *downstreamConn) handleMessageRegistered(ctx context.Context, msg *irc.
// TODO: limit max number of subscriptions, prune old ones
- if err := dc.user.srv.db.StoreWebPushSubscription(ctx, networkID, &newSub); err != nil {
+ if err := dc.user.srv.db.StoreWebPushSubscription(ctx, dc.user.ID, networkID, &newSub); err != nil {
dc.logger.Printf("failed to store Web push subscription: %v", err)
return ircError{&irc.Message{
Command: "FAIL",
@@ 3382,7 3382,7 @@ func (dc *downstreamConn) findWebPushSubscription(ctx context.Context, endpoint
networkID = dc.network.ID
}
- subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, networkID)
+ subs, err := dc.user.srv.db.ListWebPushSubscriptions(ctx, dc.user.ID, networkID)
if err != nil {
return nil, err
}
M user.go => user.go +1 -1
@@ 445,7 445,7 @@ func (net *network) autoSaveSASLPlain(ctx context.Context, username, password st
}
func (net *network) broadcastWebPush(ctx context.Context, msg *irc.Message) {
- subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.ID)
+ subs, err := net.user.srv.db.ListWebPushSubscriptions(ctx, net.user.ID, net.ID)
if err != nil {
net.logger.Printf("failed to list Web push subscriptions: %v", err)
return