diff --git a/syncapi/storage/interface.go b/syncapi/storage/interface.go index 9cff4cad..313b8758 100644 --- a/syncapi/storage/interface.go +++ b/syncapi/storage/interface.go @@ -140,4 +140,6 @@ type Database interface { StoreReceipt(ctx context.Context, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) // GetRoomReceipts gets all receipts for a given roomID GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) + + GetPaginatedRooms(ctx context.Context, userID string, offset, count int) ([]string, error) } diff --git a/syncapi/storage/postgres/memberships_table.go b/syncapi/storage/postgres/memberships_table.go index 6566544d..6397626c 100644 --- a/syncapi/storage/postgres/memberships_table.go +++ b/syncapi/storage/postgres/memberships_table.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -62,9 +63,14 @@ const selectMembershipSQL = "" + " ORDER BY stream_pos DESC" + " LIMIT 1" +const selectMembershipsSQL = "" + + "SELECT room_id, membership FROM syncapi_memberships" + + " WHERE user_id = $1" + type membershipsStatements struct { - upsertMembershipStmt *sql.Stmt - selectMembershipStmt *sql.Stmt + upsertMembershipStmt *sql.Stmt + selectMembershipStmt *sql.Stmt + selectMembershipsStmt *sql.Stmt } func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -79,6 +85,9 @@ func NewPostgresMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.selectMembershipStmt, err = db.Prepare(selectMembershipSQL); err != nil { return nil, err } + if s.selectMembershipsStmt, err = db.Prepare(selectMembershipsSQL); err != nil { + return nil, err + } return s, nil } @@ -109,3 +118,24 @@ func (s *membershipsStatements) SelectMembership( err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos) return } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, userID string, +) (map[string]string, error) { + var roomID, membership string + result := map[string]string{} + stmt := sqlutil.TxStmt(txn, s.selectMembershipsStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, fmt.Errorf("stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + for rows.Next() { + err = rows.Scan(&roomID, &membership) + if err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + result[roomID] = membership + } + return result, nil +} diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index bd7aa018..7aa74ad5 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -19,6 +19,7 @@ import ( "context" "database/sql" "encoding/json" + "fmt" "sort" "github.com/matrix-org/dendrite/internal" @@ -127,19 +128,27 @@ const selectStateInRangeSQL = "" + " ORDER BY id ASC" + " LIMIT $8" +const bulkSelectMaxStreamPositionsSQL = "" + + "SELECT room_id, MAX(id) AS id FROM syncapi_output_room_events" + + " WHERE room_id = ANY($1)" + + " GROUP BY room_id" + + " ORDER BY id DESC" + + " OFFSET $2 LIMIT $3" + const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" type outputRoomEventsStatements struct { - insertEventStmt *sql.Stmt - selectEventsStmt *sql.Stmt - selectMaxEventIDStmt *sql.Stmt - selectRecentEventsStmt *sql.Stmt - selectRecentEventsForSyncStmt *sql.Stmt - selectEarlyEventsStmt *sql.Stmt - selectStateInRangeStmt *sql.Stmt - updateEventJSONStmt *sql.Stmt - deleteEventsForRoomStmt *sql.Stmt + insertEventStmt *sql.Stmt + selectEventsStmt *sql.Stmt + selectMaxEventIDStmt *sql.Stmt + selectRecentEventsStmt *sql.Stmt + selectRecentEventsForSyncStmt *sql.Stmt + selectEarlyEventsStmt *sql.Stmt + selectStateInRangeStmt *sql.Stmt + updateEventJSONStmt *sql.Stmt + deleteEventsForRoomStmt *sql.Stmt + bulkSelectMaxStreamPositionsStmt *sql.Stmt } func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { @@ -175,6 +184,9 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { if s.deleteEventsForRoomStmt, err = db.Prepare(deleteEventsForRoomSQL); err != nil { return nil, err } + if s.bulkSelectMaxStreamPositionsStmt, err = db.Prepare(bulkSelectMaxStreamPositionsSQL); err != nil { + return nil, err + } return s, nil } @@ -435,6 +447,28 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom( return err } +func (s *outputRoomEventsStatements) BulkSelectMaxStreamPositions( + ctx context.Context, txn *sql.Tx, roomIDs []string, offset, count int, +) (map[string]types.StreamPosition, error) { + result := map[string]types.StreamPosition{} + stmt := sqlutil.TxStmt(txn, s.bulkSelectMaxStreamPositionsStmt) + rows, err := stmt.QueryContext(ctx, roomIDs, offset, count) + if err != nil { + return nil, fmt.Errorf("stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + for rows.Next() { + var roomID string + var pos types.StreamPosition + err = rows.Scan(&roomID, &pos) + if err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + result[roomID] = pos + } + return result, nil +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index b8271877..933b2ed3 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -956,3 +956,23 @@ func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, stream _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) return receipts, err } + +func (d *Database) GetPaginatedRooms(ctx context.Context, userID string, offset, count int) ([]string, error) { + memberships, err := d.Memberships.SelectMemberships(ctx, nil, userID) + if err != nil { + return nil, fmt.Errorf("d.Memberships.SelectMemberships: %w", err) + } + rooms := []string{} + for roomID := range memberships { + rooms = append(rooms, roomID) + } + positions, err := d.OutputEvents.BulkSelectMaxStreamPositions(ctx, nil, rooms, offset, count) + if err != nil { + return nil, fmt.Errorf("d.Events.BulkSelectMaxStreamPositions: %w", err) + } + rooms = rooms[:0] + for roomID := range positions { + rooms = append(rooms, roomID) + } + return rooms, nil +} diff --git a/syncapi/storage/sqlite3/memberships_table.go b/syncapi/storage/sqlite3/memberships_table.go index e5445e81..17736ce3 100644 --- a/syncapi/storage/sqlite3/memberships_table.go +++ b/syncapi/storage/sqlite3/memberships_table.go @@ -20,6 +20,7 @@ import ( "fmt" "strings" + "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/types" @@ -63,9 +64,14 @@ const selectMembershipSQL = "" + " ORDER BY stream_pos DESC" + " LIMIT 1" +const selectMembershipsSQL = "" + + "SELECT room_id, membership FROM syncapi_memberships" + + " WHERE user_id = $1" + type membershipsStatements struct { - db *sql.DB - upsertMembershipStmt *sql.Stmt + db *sql.DB + upsertMembershipStmt *sql.Stmt + selectMembershipsStmt *sql.Stmt } func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { @@ -79,6 +85,9 @@ func NewSqliteMembershipsTable(db *sql.DB) (tables.Memberships, error) { if s.upsertMembershipStmt, err = db.Prepare(upsertMembershipSQL); err != nil { return nil, err } + if s.selectMembershipsStmt, err = db.Prepare(selectMembershipsSQL); err != nil { + return nil, err + } return s, nil } @@ -117,3 +126,24 @@ func (s *membershipsStatements) SelectMembership( err = sqlutil.TxStmt(txn, stmt).QueryRowContext(ctx, params...).Scan(&eventID, &streamPos, &topologyPos) return } + +func (s *membershipsStatements) SelectMemberships( + ctx context.Context, txn *sql.Tx, userID string, +) (map[string]string, error) { + var roomID, membership string + result := map[string]string{} + stmt := sqlutil.TxStmt(txn, s.selectMembershipsStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, fmt.Errorf("stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + for rows.Next() { + err = rows.Scan(&roomID, &membership) + if err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + result[roomID] = membership + } + return result, nil +} diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 37f7ea00..453e5727 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -90,6 +91,13 @@ const selectStateInRangeSQL = "" + const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" +const bulkSelectMaxStreamPositionsSQL = "" + + "SELECT room_id, MAX(id) AS id FROM syncapi_output_room_events" + + " WHERE room_id IN ($1)" + + " GROUP BY room_id" + + " ORDER BY id DESC" + + " LIMIT $2 OFFSET $3" + type outputRoomEventsStatements struct { db *sql.DB streamIDStatements *streamIDStatements @@ -424,6 +432,40 @@ func (s *outputRoomEventsStatements) DeleteEventsForRoom( return err } +func (s *outputRoomEventsStatements) BulkSelectMaxStreamPositions( + ctx context.Context, txn *sql.Tx, roomIDs []string, offset, count int, +) (map[string]types.StreamPosition, error) { + origSQL := strings.Replace(bulkSelectMaxStreamPositionsSQL, "$2", fmt.Sprintf("$%d", len(roomIDs)+1), 1) + origSQL = strings.Replace(origSQL, "$3", fmt.Sprintf("$%d", len(roomIDs)+2), 1) + origSQL = strings.Replace(origSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1) + origStmt, err := s.db.Prepare(origSQL) + if err != nil { + return nil, fmt.Errorf("s.db.Prepare: %w", err) + } + params := []interface{}{} + for _, roomID := range roomIDs { + params = append(params, roomID) + } + params = append(params, count, offset) + result := map[string]types.StreamPosition{} + stmt := sqlutil.TxStmt(txn, origStmt) + rows, err := stmt.QueryContext(ctx, params...) + if err != nil { + return nil, fmt.Errorf("stmt.QueryContext: %w", err) + } + defer internal.CloseAndLogIfError(ctx, rows, "failed to close rows") + for rows.Next() { + var roomID string + var pos types.StreamPosition + err = rows.Scan(&roomID, &pos) + if err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + result[roomID] = pos + } + return result, nil +} + func rowsToStreamEvents(rows *sql.Rows) ([]types.StreamEvent, error) { var result []types.StreamEvent for rows.Next() { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 02887271..6b0e5754 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -63,6 +63,7 @@ type Events interface { UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error // DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely. DeleteEventsForRoom(ctx context.Context, txn *sql.Tx, roomID string) (err error) + BulkSelectMaxStreamPositions(ctx context.Context, txn *sql.Tx, roomIDs []string, offset, count int) (map[string]types.StreamPosition, error) } // Topology keeps track of the depths and stream positions for all events. @@ -166,4 +167,5 @@ type Receipts interface { type Memberships interface { UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error) + SelectMemberships(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) } diff --git a/syncapi/sync/request.go b/syncapi/sync/request.go index 09a62e3d..d51e6c0e 100644 --- a/syncapi/sync/request.go +++ b/syncapi/sync/request.go @@ -69,6 +69,18 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat } } + switch req.URL.Query().Get("paginate_by") { + case "latest": + offset, _ := strconv.Atoi(req.URL.Query().Get("offset")) + count, _ := strconv.Atoi(req.URL.Query().Get("count")) + rooms, err := syncDB.GetPaginatedRooms(req.Context(), device.UserID, offset, count) + if err != nil { + return nil, fmt.Errorf("syncDB.GetPaginatedRooms: %w", err) + } + filter.Room.Rooms = rooms + logrus.Warnf("Filtering by rooms: %v", rooms) + } + logger := util.GetLogger(req.Context()).WithFields(logrus.Fields{ "user_id": device.UserID, "device_id": device.ID,