diff --git a/syncapi/storage/postgres/current_room_state_table.go b/syncapi/storage/postgres/current_room_state_table.go index ee649c16..d646a0e4 100644 --- a/syncapi/storage/postgres/current_room_state_table.go +++ b/syncapi/storage/postgres/current_room_state_table.go @@ -77,6 +77,9 @@ const DeleteRoomStateForRoomSQL = "" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" +const selectRoomIDsWithAnyMembershipSQL = "" + + "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" + " AND ( $2::text[] IS NULL OR sender = ANY($2) )" + @@ -102,14 +105,15 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id = ANY($1)" type currentRoomStateStatements struct { - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - DeleteRoomStateForRoomStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectCurrentStateStmt *sql.Stmt - selectJoinedUsersStmt *sql.Stmt - selectEventsWithEventIDsStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectRoomIDsWithAnyMembershipStmt *sql.Stmt + selectCurrentStateStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt + selectEventsWithEventIDsStmt *sql.Stmt + selectStateEventStmt *sql.Stmt } func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error) { @@ -130,6 +134,9 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } + if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { + return nil, err + } if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil { return nil, err } @@ -194,6 +201,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( return result, rows.Err() } +// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. +func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( + ctx context.Context, + txn *sql.Tx, + userID string, +) (map[string]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed") + + result := map[string]string{} + for rows.Next() { + var roomID string + var membership string + if err := rows.Scan(&roomID, &membership); err != nil { + return nil, err + } + result[roomID] = membership + } + return result, rows.Err() +} + // SelectCurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, diff --git a/syncapi/storage/postgres/output_room_events_table.go b/syncapi/storage/postgres/output_room_events_table.go index d4cc4f3f..26689f44 100644 --- a/syncapi/storage/postgres/output_room_events_table.go +++ b/syncapi/storage/postgres/output_room_events_table.go @@ -119,13 +119,14 @@ const selectStateInRangeSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2) AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" + - " AND ( $3::text[] IS NULL OR sender = ANY($3) )" + - " AND ( $4::text[] IS NULL OR NOT(sender = ANY($4)) )" + - " AND ( $5::text[] IS NULL OR type LIKE ANY($5) )" + - " AND ( $6::text[] IS NULL OR NOT(type LIKE ANY($6)) )" + - " AND ( $7::bool IS NULL OR contains_url = $7 )" + + " AND room_id = ANY($3)" + + " AND ( $4::text[] IS NULL OR sender = ANY($4) )" + + " AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" + + " AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" + + " AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" + + " AND ( $8::bool IS NULL OR contains_url = $8 )" + " ORDER BY id ASC" + - " LIMIT $8" + " LIMIT $9" const deleteEventsForRoomSQL = "" + "DELETE FROM syncapi_output_room_events WHERE room_id = $1" @@ -200,12 +201,12 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt) rows, err := stmt.QueryContext( - ctx, r.Low(), r.High(), + ctx, r.Low(), r.High(), pq.StringArray(roomIDs), pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.NotSenders), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 87d7c6df..2c166eef 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -689,10 +689,26 @@ func (d *Database) GetStateDeltas( var succeeded bool defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) + if err != nil { + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + var deltas []types.StateDelta // get all the state events ever (i.e. for all available rooms) between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { return nil, nil, err } @@ -760,10 +776,6 @@ func (d *Database) GetStateDeltas( } // Add in currently joined rooms - joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, types.StateDelta{ Membership: gomatrixserverlib.Join, @@ -792,6 +804,22 @@ func (d *Database) GetStateDeltasForFullStateSync( var succeeded bool defer sqlutil.EndTransactionWithCheck(txn, &succeeded, &err) + // Look up all memberships for the user. We only care about rooms that a + // user has ever interacted with — joined to, kicked/banned from, left. + memberships, err := d.CurrentRoomState.SelectRoomIDsWithAnyMembership(ctx, txn, userID) + if err != nil { + return nil, nil, err + } + + allRoomIDs := make([]string, 0, len(memberships)) + joinedRoomIDs := make([]string, 0, len(memberships)) + for roomID, membership := range memberships { + allRoomIDs = append(allRoomIDs, roomID) + if membership == gomatrixserverlib.Join { + joinedRoomIDs = append(joinedRoomIDs, roomID) + } + } + // Use a reasonable initial capacity deltas := make(map[string]types.StateDelta) @@ -816,7 +844,7 @@ func (d *Database) GetStateDeltasForFullStateSync( } // Get all the state events ever between these two positions - stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter) + stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter, allRoomIDs) if err != nil { return nil, nil, err } @@ -842,11 +870,6 @@ func (d *Database) GetStateDeltasForFullStateSync( } } - joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) - if err != nil { - return nil, nil, err - } - // Add full states for all joined rooms for _, joinedRoomID := range joinedRoomIDs { s, stateErr := d.currentStateStreamEventsForRoom(ctx, txn, joinedRoomID, stateFilter) diff --git a/syncapi/storage/sqlite3/current_room_state_table.go b/syncapi/storage/sqlite3/current_room_state_table.go index c91ca692..587f9d24 100644 --- a/syncapi/storage/sqlite3/current_room_state_table.go +++ b/syncapi/storage/sqlite3/current_room_state_table.go @@ -66,6 +66,9 @@ const DeleteRoomStateForRoomSQL = "" + const selectRoomIDsWithMembershipSQL = "" + "SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2" +const selectRoomIDsWithAnyMembershipSQL = "" + + "SELECT DISTINCT room_id, membership FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1" + const selectCurrentStateSQL = "" + "SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" @@ -86,14 +89,15 @@ const selectEventsWithEventIDsSQL = "" + " FROM syncapi_current_room_state WHERE event_id IN ($1)" type currentRoomStateStatements struct { - db *sql.DB - streamIDStatements *streamIDStatements - upsertRoomStateStmt *sql.Stmt - deleteRoomStateByEventIDStmt *sql.Stmt - DeleteRoomStateForRoomStmt *sql.Stmt - selectRoomIDsWithMembershipStmt *sql.Stmt - selectJoinedUsersStmt *sql.Stmt - selectStateEventStmt *sql.Stmt + db *sql.DB + streamIDStatements *streamIDStatements + upsertRoomStateStmt *sql.Stmt + deleteRoomStateByEventIDStmt *sql.Stmt + DeleteRoomStateForRoomStmt *sql.Stmt + selectRoomIDsWithMembershipStmt *sql.Stmt + selectRoomIDsWithAnyMembershipStmt *sql.Stmt + selectJoinedUsersStmt *sql.Stmt + selectStateEventStmt *sql.Stmt } func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) { @@ -117,6 +121,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil { return nil, err } + if s.selectRoomIDsWithAnyMembershipStmt, err = db.Prepare(selectRoomIDsWithAnyMembershipSQL); err != nil { + return nil, err + } if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil { return nil, err } @@ -175,6 +182,31 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership( return result, nil } +// SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. +func (s *currentRoomStateStatements) SelectRoomIDsWithAnyMembership( + ctx context.Context, + txn *sql.Tx, + userID string, +) (map[string]string, error) { + stmt := sqlutil.TxStmt(txn, s.selectRoomIDsWithAnyMembershipStmt) + rows, err := stmt.QueryContext(ctx, userID) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsWithAnyMembership: rows.close() failed") + + result := map[string]string{} + for rows.Next() { + var roomID string + var membership string + if err := rows.Scan(&roomID, &membership); err != nil { + return nil, err + } + result[roomID] = membership + } + return result, rows.Err() +} + // CurrentState returns all the current state events for the given room. func (s *currentRoomStateStatements) SelectCurrentState( ctx context.Context, txn *sql.Tx, roomID string, diff --git a/syncapi/storage/sqlite3/output_room_events_table.go b/syncapi/storage/sqlite3/output_room_events_table.go index 1b256f91..b9115262 100644 --- a/syncapi/storage/sqlite3/output_room_events_table.go +++ b/syncapi/storage/sqlite3/output_room_events_table.go @@ -21,6 +21,7 @@ import ( "encoding/json" "fmt" "sort" + "strings" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" @@ -87,6 +88,7 @@ const selectStateInRangeSQL = "" + "SELECT event_id, id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" + " FROM syncapi_output_room_events" + " WHERE (id > $1 AND id <= $2)" + + " AND room_id IN ($3)" + " AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))" // WHEN, ORDER BY and LIMIT are appended by prepareWithFilters @@ -155,13 +157,17 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event // two positions, only the most recent state is returned. func (s *outputRoomEventsStatements) SelectStateInRange( ctx context.Context, txn *sql.Tx, r types.Range, - stateFilter *gomatrixserverlib.StateFilter, + stateFilter *gomatrixserverlib.StateFilter, roomIDs []string, ) (map[string]map[string]bool, map[string]types.StreamEvent, error) { + stmtSQL := strings.Replace(selectStateInRangeSQL, "($3)", sqlutil.QueryVariadicOffset(len(roomIDs), 2), 1) + inputParams := []interface{}{ + r.Low(), r.High(), + } + for _, roomID := range roomIDs { + inputParams = append(inputParams, roomID) + } stmt, params, err := prepareWithFilters( - s.db, txn, selectStateInRangeSQL, - []interface{}{ - r.Low(), r.High(), - }, + s.db, txn, stmtSQL, inputParams, stateFilter.Senders, stateFilter.NotSenders, stateFilter.Types, stateFilter.NotTypes, nil, stateFilter.Limit, FilterOrderAsc, diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index 1ebb4265..9d1078f5 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -51,7 +51,7 @@ type Peeks interface { } type Events interface { - SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) + SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter, roomIDs []string) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high. @@ -99,6 +99,8 @@ type CurrentRoomState interface { SelectCurrentState(ctx context.Context, txn *sql.Tx, roomID string, stateFilter *gomatrixserverlib.StateFilter, excludeEventIDs []string) ([]*gomatrixserverlib.HeaderedEvent, error) // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) + // SelectRoomIDsWithAnyMembership returns a map of all memberships for the given user. + SelectRoomIDsWithAnyMembership(ctx context.Context, txn *sql.Tx, userID string) (map[string]string, error) // SelectJoinedUsers returns a map of room ID to a list of joined user IDs. SelectJoinedUsers(ctx context.Context) (map[string][]string, error) }