Limit JoinedUsersSetInRooms to interested users (#2234)

* Limit database work in `JoinedUsersSetInRooms` to changed user IDs only

* Comments

* Fix variadic params for SQLite, update comments
This commit is contained in:
Neil Alexander 2022-03-01 13:01:38 +00:00 committed by GitHub
parent 58bf91a585
commit 530f05885d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 49 additions and 31 deletions

View file

@ -269,6 +269,7 @@ type QueryAuthChainResponse struct {
type QuerySharedUsersRequest struct {
UserID string
OtherUserIDs []string
ExcludeRoomIDs []string
IncludeRoomIDs []string
}

View file

@ -696,7 +696,7 @@ func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUser
}
roomIDs = roomIDs[:j]
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs)
users, err := r.DB.JoinedUsersSetInRooms(ctx, roomIDs, req.OtherUserIDs)
if err != nil {
return err
}

View file

@ -151,8 +151,8 @@ type Database interface {
// GetBulkStateContent returns all state events which match a given room ID and a given state key tuple. Both must be satisfied for a match.
// If a tuple has the StateKey of '*' and allowWildcards=true then all state events with the EventType should be returned.
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
// JoinedUsersSetInRooms returns how many times each of the given users appears across the given rooms.
JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.

View file

@ -66,7 +66,8 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid = ANY($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -306,13 +307,10 @@ func (s *membershipStatements) SelectRoomsWithMembership(
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
userNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
if err != nil {
return nil, err
}

View file

@ -1104,13 +1104,23 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
return result, nil
}
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
// JoinedUsersSetInRooms returns a map of how many times the given users appear in the specified rooms.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
userNIDsMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, userIDs)
if err != nil {
return nil, err
}
userNIDs := make([]types.EventStateKeyNID, 0, len(userNIDsMap))
nidToUserID := make(map[types.EventStateKeyNID]string, len(userNIDsMap))
for id, nid := range userNIDsMap {
userNIDs = append(userNIDs, nid)
nidToUserID[nid] = id
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs, userNIDs)
if err != nil {
return nil, err
}
@ -1120,10 +1130,6 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
if len(nidToUserID) != len(userNIDToCount) {
logrus.Warnf("SelectJoinedUsersSetForRooms found %d users but BulkSelectEventStateKey only returned state key NIDs for %d of them", len(userNIDToCount), len(nidToUserID))
}

View file

@ -42,7 +42,8 @@ const membershipSchema = `
`
var selectJoinedUsersSetForRoomsSQL = "" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership WHERE room_nid IN ($1) AND" +
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
" GROUP BY target_nid"
@ -280,18 +281,22 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error) {
params := make([]interface{}, 0, len(roomNIDs)+len(userNIDs))
for _, v := range roomNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
for _, v := range userNIDs {
params = append(params, v)
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
rows, err = txn.QueryContext(ctx, query, params...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
rows, err = s.db.QueryContext(ctx, query, params...)
}
if err != nil {
return nil, err

View file

@ -127,9 +127,8 @@ type Membership interface {
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
// SelectJoinedUsersSetForRooms returns how many times each of the given users appears across the given rooms.
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID, userNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)