mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-29 12:42:46 +00:00
Fix QuerySharedUsers for the SyncAPI keychange consumer (#2554)
* Make more use of base.BaseDendrite * Fix QuerySharedUsers if no UserIDs are supplied
This commit is contained in:
parent
f29cdb26f6
commit
5087b36af0
7 changed files with 137 additions and 31 deletions
|
@ -65,12 +65,18 @@ CREATE TABLE IF NOT EXISTS roomserver_membership (
|
|||
);
|
||||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND target_nid = ANY($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid = ANY($1) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
|
@ -153,6 +159,7 @@ type membershipStatements struct {
|
|||
selectLocalMembershipsFromRoomStmt *sql.Stmt
|
||||
updateMembershipStmt *sql.Stmt
|
||||
selectRoomsWithMembershipStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsAndUserStmt *sql.Stmt
|
||||
selectJoinedUsersSetForRoomsStmt *sql.Stmt
|
||||
selectKnownUsersStmt *sql.Stmt
|
||||
updateMembershipForgetRoomStmt *sql.Stmt
|
||||
|
@ -178,6 +185,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
{&s.selectLocalMembershipsFromRoomStmt, selectLocalMembershipsFromRoomSQL},
|
||||
{&s.updateMembershipStmt, updateMembershipSQL},
|
||||
{&s.selectRoomsWithMembershipStmt, selectRoomsWithMembershipSQL},
|
||||
{&s.selectJoinedUsersSetForRoomsAndUserStmt, selectJoinedUsersSetForRoomsAndUserSQL},
|
||||
{&s.selectJoinedUsersSetForRoomsStmt, selectJoinedUsersSetForRoomsSQL},
|
||||
{&s.selectKnownUsersStmt, selectKnownUsersSQL},
|
||||
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
|
||||
|
@ -313,8 +321,18 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
|||
roomNIDs []types.RoomNID,
|
||||
userNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
var (
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
if len(userNIDs) > 0 {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsAndUserStmt)
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs), pq.Array(userNIDs))
|
||||
} else {
|
||||
rows, err = stmt.QueryContext(ctx, pq.Array(roomNIDs))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -1214,6 +1214,13 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
|||
stateKeyNIDs[i] = nid
|
||||
i++
|
||||
}
|
||||
// If we didn't have any userIDs to look up, get the UserIDs for the returned userNIDToCount now
|
||||
if len(userIDs) == 0 {
|
||||
nidToUserID, err = d.EventStateKeys(ctx, stateKeyNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
result := make(map[string]int, len(userNIDToCount))
|
||||
for nid, count := range userNIDToCount {
|
||||
result[nidToUserID[nid]] = count
|
||||
|
|
|
@ -41,12 +41,18 @@ const membershipSchema = `
|
|||
);
|
||||
`
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
var selectJoinedUsersSetForRoomsAndUserSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND target_nid IN ($2) AND" +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
var selectJoinedUsersSetForRoomsSQL = "" +
|
||||
"SELECT target_nid, COUNT(room_nid) FROM roomserver_membership" +
|
||||
" WHERE room_nid IN ($1) AND " +
|
||||
" membership_nid = " + fmt.Sprintf("%d", tables.MembershipStateJoin) + " and forgotten = false" +
|
||||
" GROUP BY target_nid"
|
||||
|
||||
// Insert a row in to membership table so that it can be locked by the
|
||||
// SELECT FOR UPDATE
|
||||
const insertMembershipSQL = "" +
|
||||
|
@ -293,8 +299,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
for _, v := range userNIDs {
|
||||
params = append(params, v)
|
||||
}
|
||||
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
if len(userNIDs) > 0 {
|
||||
query = strings.Replace(selectJoinedUsersSetForRoomsAndUserSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
query = strings.Replace(query, "($2)", sqlutil.QueryVariadicOffset(len(userNIDs), len(roomNIDs)), 1)
|
||||
}
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue