Partly fix notification counts (#2621)

* Fix notification query

* Also for SQLite

* Move tests to whitelist

* Revert "Move tests to whitelist"

This reverts commit a7d0120019a111ce45a447ba40233d9c101e6e9b.
This commit is contained in:
Till 2022-08-05 13:44:20 +02:00 committed by GitHub
parent 2a1df0129e
commit cecd11be9a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 21 additions and 7 deletions

View file

@ -58,7 +58,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count) (user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET notification_count = $3, highlight_count = $4 DO UPDATE SET id = nextval('syncapi_notification_data_id_seq'), notification_count = $3, highlight_count = $4
RETURNING id` RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationCountsSQL = `SELECT

View file

@ -25,12 +25,14 @@ import (
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
) )
func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error) { func NewSqliteNotificationDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.NotificationData, error) {
_, err := db.Exec(notificationDataSchema) _, err := db.Exec(notificationDataSchema)
if err != nil { if err != nil {
return nil, err return nil, err
} }
r := &notificationDataStatements{} r := &notificationDataStatements{
streamIDStatements: streamID,
}
return r, sqlutil.StatementList{ return r, sqlutil.StatementList{
{&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL}, {&r.upsertRoomUnreadCounts, upsertRoomUnreadNotificationCountsSQL},
{&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL}, {&r.selectUserUnreadCounts, selectUserUnreadNotificationCountsSQL},
@ -39,6 +41,7 @@ func NewSqliteNotificationDataTable(db *sql.DB) (tables.NotificationData, error)
} }
type notificationDataStatements struct { type notificationDataStatements struct {
streamIDStatements *StreamIDStatements
upsertRoomUnreadCounts *sql.Stmt upsertRoomUnreadCounts *sql.Stmt
selectUserUnreadCounts *sql.Stmt selectUserUnreadCounts *sql.Stmt
selectMaxID *sql.Stmt selectMaxID *sql.Stmt
@ -58,8 +61,7 @@ const upsertRoomUnreadNotificationCountsSQL = `INSERT INTO syncapi_notification_
(user_id, room_id, notification_count, highlight_count) (user_id, room_id, notification_count, highlight_count)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, room_id) ON CONFLICT (user_id, room_id)
DO UPDATE SET notification_count = $3, highlight_count = $4 DO UPDATE SET id = $5, notification_count = $6, highlight_count = $7`
RETURNING id`
const selectUserUnreadNotificationCountsSQL = `SELECT const selectUserUnreadNotificationCountsSQL = `SELECT
id, room_id, notification_count, highlight_count id, room_id, notification_count, highlight_count
@ -71,7 +73,11 @@ const selectUserUnreadNotificationCountsSQL = `SELECT
const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data` const selectMaxNotificationIDSQL = `SELECT CASE COUNT(*) WHEN 0 THEN 0 ELSE MAX(id) END FROM syncapi_notification_data`
func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) { func (r *notificationDataStatements) UpsertRoomUnreadCounts(ctx context.Context, userID, roomID string, notificationCount, highlightCount int) (pos types.StreamPosition, err error) {
err = r.upsertRoomUnreadCounts.QueryRowContext(ctx, userID, roomID, notificationCount, highlightCount).Scan(&pos) pos, err = r.streamIDStatements.nextNotificationID(ctx, nil)
if err != nil {
return
}
_, err = r.upsertRoomUnreadCounts.ExecContext(ctx, userID, roomID, notificationCount, highlightCount, pos, notificationCount, highlightCount)
return return
} }

View file

@ -26,6 +26,8 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0) INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("presence", 0)
ON CONFLICT DO NOTHING; ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("notification", 0)
ON CONFLICT DO NOTHING;
` `
const increaseStreamIDStmt = "" + const increaseStreamIDStmt = "" +
@ -78,3 +80,9 @@ func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (p
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos) err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
return return
} }
func (s *StreamIDStatements) nextNotificationID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
err = increaseStmt.QueryRowContext(ctx, "notification").Scan(&pos)
return
}

View file

@ -95,7 +95,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil { if err != nil {
return err return err
} }
notificationData, err := NewSqliteNotificationDataTable(d.db) notificationData, err := NewSqliteNotificationDataTable(d.db, &d.streamID)
if err != nil { if err != nil {
return err return err
} }