Modify QuerySharedUsers to handle counts/include/exclude (#1219)

* Modify QuerySharedUsers to handle counts/include/exclude

We will need this functionality when working out whether to
send device list changes to users who have joined/left a room.

* Linting
This commit is contained in:
Kegsay 2020-07-24 10:33:41 +01:00 committed by GitHub
parent 98f2f09bb4
commit af5b4d1f6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 95 additions and 29 deletions

View file

@ -36,11 +36,13 @@ type CurrentStateInternalAPI interface {
} }
type QuerySharedUsersRequest struct { type QuerySharedUsersRequest struct {
UserID string UserID string
ExcludeRoomIDs []string
IncludeRoomIDs []string
} }
type QuerySharedUsersResponse struct { type QuerySharedUsersResponse struct {
UserIDs []string UserIDsToCount map[string]int
} }
type QueryRoomsForUserRequest struct { type QueryRoomsForUserRequest struct {

View file

@ -20,7 +20,6 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect" "reflect"
"sort"
"testing" "testing"
"time" "time"
@ -227,13 +226,31 @@ func TestQuerySharedUsers(t *testing.T) {
req api.QuerySharedUsersRequest req api.QuerySharedUsersRequest
wantRes api.QuerySharedUsersResponse wantRes api.QuerySharedUsersResponse
}{ }{
// Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A,B,C) // Simple case: sharing (A,B) (A,C) (A,B) (A) produces (A:4,B:2,C:1)
{ {
req: api.QuerySharedUsersRequest{ req: api.QuerySharedUsersRequest{
UserID: "@alice:localhost", UserID: "@alice:localhost",
}, },
wantRes: api.QuerySharedUsersResponse{ wantRes: api.QuerySharedUsersResponse{
UserIDs: []string{"@alice:localhost", "@bob:localhost", "@charlie:localhost"}, UserIDsToCount: map[string]int{
"@alice:localhost": 4,
"@bob:localhost": 2,
"@charlie:localhost": 1,
},
},
},
// Exclude (A,C): sharing (A,B) (A,B) (A) produces (A:3,B:2)
{
req: api.QuerySharedUsersRequest{
UserID: "@alice:localhost",
ExcludeRoomIDs: []string{"!foo2:bar"},
},
wantRes: api.QuerySharedUsersResponse{
UserIDsToCount: map[string]int{
"@alice:localhost": 3,
"@bob:localhost": 2,
},
}, },
}, },
@ -243,7 +260,7 @@ func TestQuerySharedUsers(t *testing.T) {
UserID: "@unknownuser:localhost", UserID: "@unknownuser:localhost",
}, },
wantRes: api.QuerySharedUsersResponse{ wantRes: api.QuerySharedUsersResponse{
UserIDs: nil, UserIDsToCount: map[string]int{},
}, },
}, },
@ -253,7 +270,35 @@ func TestQuerySharedUsers(t *testing.T) {
UserID: "@dave:localhost", UserID: "@dave:localhost",
}, },
wantRes: api.QuerySharedUsersResponse{ wantRes: api.QuerySharedUsersResponse{
UserIDs: nil, UserIDsToCount: map[string]int{},
},
},
// left real user but with included room returns the included room member
{
req: api.QuerySharedUsersRequest{
UserID: "@dave:localhost",
IncludeRoomIDs: []string{"!foo:bar"},
},
wantRes: api.QuerySharedUsersResponse{
UserIDsToCount: map[string]int{
"@alice:localhost": 1,
"@bob:localhost": 1,
},
},
},
// including a room more than once doesn't double counts
{
req: api.QuerySharedUsersRequest{
UserID: "@dave:localhost",
IncludeRoomIDs: []string{"!foo:bar", "!foo:bar", "!foo:bar"},
},
wantRes: api.QuerySharedUsersResponse{
UserIDsToCount: map[string]int{
"@alice:localhost": 1,
"@bob:localhost": 1,
},
}, },
}, },
} }
@ -266,10 +311,8 @@ func TestQuerySharedUsers(t *testing.T) {
t.Errorf("QuerySharedUsers returned error: %s", err) t.Errorf("QuerySharedUsers returned error: %s", err)
continue continue
} }
sort.Strings(res.UserIDs) if !reflect.DeepEqual(res.UserIDsToCount, tc.wantRes.UserIDsToCount) {
sort.Strings(tc.wantRes.UserIDs) t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDsToCount, tc.wantRes.UserIDsToCount)
if !reflect.DeepEqual(res.UserIDs, tc.wantRes.UserIDs) {
t.Errorf("QuerySharedUsers got users %+v want %+v", res.UserIDs, tc.wantRes.UserIDs)
} }
} }
} }

View file

@ -74,10 +74,27 @@ func (a *CurrentStateInternalAPI) QuerySharedUsers(ctx context.Context, req *api
if err != nil { if err != nil {
return err return err
} }
roomIDs = append(roomIDs, req.IncludeRoomIDs...)
excludeMap := make(map[string]bool)
for _, roomID := range req.ExcludeRoomIDs {
excludeMap[roomID] = true
}
// filter out excluded rooms
j := 0
for i := range roomIDs {
// move elements to include to the beginning of the slice
// then trim elements on the right
if !excludeMap[roomIDs[i]] {
roomIDs[j] = roomIDs[i]
j++
}
}
roomIDs = roomIDs[:j]
users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs) users, err := a.DB.JoinedUsersSetInRooms(ctx, roomIDs)
if err != nil { if err != nil {
return err return err
} }
res.UserIDs = users res.UserIDsToCount = users
return nil return nil
} }

View file

@ -37,6 +37,6 @@ type Database interface {
GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error) GetBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]tables.StrippedEvent, error)
// Redact a state event // Redact a state event
RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error RedactEvent(ctx context.Context, redactedEventID string, redactedBecause gomatrixserverlib.HeaderedEvent) error
// JoinedUsersSetInRooms returns all joined users in the rooms given. // JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
} }

View file

@ -78,7 +78,8 @@ const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = ANY($2)"
const selectJoinedUsersSetForRoomsSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id = ANY($1) AND type = 'm.room.member' and content_value = 'join'" "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id = ANY($1) AND" +
" type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
upsertRoomStateStmt *sql.Stmt upsertRoomStateStmt *sql.Stmt
@ -124,21 +125,22 @@ func NewPostgresCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, erro
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs)) rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string result := make(map[string]int)
for rows.Next() { for rows.Next() {
var userID string var userID string
if err := rows.Scan(&userID); err != nil { var count int
if err := rows.Scan(&userID, &count); err != nil {
return nil, err return nil, err
} }
userIDs = append(userIDs, userID) result[userID] = count
} }
return userIDs, rows.Err() return result, rows.Err()
} }
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.

View file

@ -86,6 +86,6 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership) return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
} }
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) ([]string, error) { func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs) return d.CurrentRoomState.SelectJoinedUsersSetForRooms(ctx, roomIDs)
} }

View file

@ -67,7 +67,7 @@ const selectBulkStateContentWildSQL = "" +
"SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)" "SELECT room_id, type, state_key, content_value FROM currentstate_current_room_state WHERE room_id IN ($1) AND type IN ($2)"
const selectJoinedUsersSetForRoomsSQL = "" + const selectJoinedUsersSetForRoomsSQL = "" +
"SELECT DISTINCT state_key FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join'" "SELECT state_key, COUNT(room_id) FROM currentstate_current_room_state WHERE room_id IN ($1) AND type = 'm.room.member' and content_value = 'join' GROUP BY state_key"
type currentRoomStateStatements struct { type currentRoomStateStatements struct {
db *sql.DB db *sql.DB
@ -106,7 +106,7 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB) (tables.CurrentRoomState, error)
return s, nil return s, nil
} }
func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) { func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
iRoomIDs := make([]interface{}, len(roomIDs)) iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs { for i, v := range roomIDs {
iRoomIDs[i] = v iRoomIDs[i] = v
@ -117,15 +117,16 @@ func (s *currentRoomStateStatements) SelectJoinedUsersSetForRooms(ctx context.Co
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedUsersSetForRooms: rows.close() failed")
var userIDs []string result := make(map[string]int)
for rows.Next() { for rows.Next() {
var userID string var userID string
if err := rows.Scan(&userID); err != nil { var count int
if err := rows.Scan(&userID, &count); err != nil {
return nil, err return nil, err
} }
userIDs = append(userIDs, userID) result[userID] = count
} }
return userIDs, rows.Err() return result, rows.Err()
} }
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.

View file

@ -36,8 +36,9 @@ type CurrentRoomState interface {
// SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state. // SelectRoomIDsWithMembership returns the list of room IDs which have the given user in the given membership state.
SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error) SelectRoomIDsWithMembership(ctx context.Context, txn *sql.Tx, userID string, membership string) ([]string, error)
SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error) SelectBulkStateContent(ctx context.Context, roomIDs []string, tuples []gomatrixserverlib.StateKeyTuple, allowWildcards bool) ([]StrippedEvent, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms. // SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) ([]string, error) // counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomIDs []string) (map[string]int, error)
} }
// StrippedEvent represents a stripped event for returning extracted content values. // StrippedEvent represents a stripped event for returning extracted content values.