mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Fix syncapi shared users query & device lists (#2614)
* Fix query issue, only add "changed" users if we actually share a room * Avoid log spam if context is done * Undo changes to filterSharedUsers * Add logging again.. * Fix SQLite shared users query * Change query to include invited users
This commit is contained in:
parent
2250768be1
commit
9fe509b18d
5 changed files with 62 additions and 43 deletions
|
@ -18,6 +18,7 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
@ -314,6 +315,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
for targetKeyID := range masterKey.Keys {
|
for targetKeyID := range masterKey.Keys {
|
||||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||||
|
// as we can't continue without a valid context.
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return
|
||||||
|
}
|
||||||
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -335,6 +341,11 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
for targetKeyID, key := range forUserID {
|
for targetKeyID, key := range forUserID {
|
||||||
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||||
|
// as we can't continue without a valid context.
|
||||||
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return
|
||||||
|
}
|
||||||
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,10 +25,9 @@ import (
|
||||||
"github.com/matrix-org/dendrite/syncapi/types"
|
"github.com/matrix-org/dendrite/syncapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
const DeviceListLogName = "dl"
|
|
||||||
|
|
||||||
// DeviceOTKCounts adds one-time key counts to the /sync response
|
// DeviceOTKCounts adds one-time key counts to the /sync response
|
||||||
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
|
func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.SyncKeyAPI, userID, deviceID string, res *types.Response) error {
|
||||||
var queryRes keyapi.QueryOneTimeKeysResponse
|
var queryRes keyapi.QueryOneTimeKeysResponse
|
||||||
|
@ -93,18 +92,13 @@ func DeviceListCatchup(
|
||||||
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
|
queryRes.UserIDs = append(queryRes.UserIDs, joinUserIDs...)
|
||||||
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
queryRes.UserIDs = append(queryRes.UserIDs, leaveUserIDs...)
|
||||||
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
queryRes.UserIDs = util.UniqueStrings(queryRes.UserIDs)
|
||||||
var sharedUsersMap map[string]int
|
sharedUsersMap := filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
|
||||||
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, db, userID, queryRes.UserIDs)
|
|
||||||
util.GetLogger(ctx).Debugf(
|
|
||||||
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
|
|
||||||
offset, toOffset, queryRes.Offset, queryRes.UserIDs,
|
|
||||||
)
|
|
||||||
userSet := make(map[string]bool)
|
userSet := make(map[string]bool)
|
||||||
for _, userID := range res.DeviceLists.Changed {
|
for _, userID := range res.DeviceLists.Changed {
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
}
|
}
|
||||||
for _, userID := range queryRes.UserIDs {
|
for userID, count := range sharedUsersMap {
|
||||||
if !userSet[userID] {
|
if !userSet[userID] && count > 0 {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
hasNew = true
|
hasNew = true
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
|
@ -113,7 +107,7 @@ func DeviceListCatchup(
|
||||||
// Finally, add in users who have joined or left.
|
// Finally, add in users who have joined or left.
|
||||||
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
// TODO: This is sub-optimal because we will add users to `changed` even if we already shared a room with them.
|
||||||
for _, userID := range joinUserIDs {
|
for _, userID := range joinUserIDs {
|
||||||
if !userSet[userID] {
|
if !userSet[userID] && sharedUsersMap[userID] > 0 {
|
||||||
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
res.DeviceLists.Changed = append(res.DeviceLists.Changed, userID)
|
||||||
hasNew = true
|
hasNew = true
|
||||||
userSet[userID] = true
|
userSet[userID] = true
|
||||||
|
@ -126,6 +120,13 @@ func DeviceListCatchup(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
|
"user_id": userID,
|
||||||
|
"from": offset,
|
||||||
|
"to": toOffset,
|
||||||
|
"response_offset": queryRes.Offset,
|
||||||
|
}).Debugf("QueryKeyChanges request result: %+v", res.DeviceLists)
|
||||||
|
|
||||||
return types.StreamPosition(queryRes.Offset), hasNew, nil
|
return types.StreamPosition(queryRes.Offset), hasNew, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -220,24 +221,27 @@ func TrackChangedUsers(
|
||||||
// it down to include only users who the requesting user shares a room with.
|
// it down to include only users who the requesting user shares a room with.
|
||||||
func filterSharedUsers(
|
func filterSharedUsers(
|
||||||
ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
|
ctx context.Context, db storage.SharedUsers, userID string, usersWithChangedKeys []string,
|
||||||
) (map[string]int, []string) {
|
) map[string]int {
|
||||||
sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
|
sharedUsersMap := make(map[string]int, len(usersWithChangedKeys))
|
||||||
for _, userID := range usersWithChangedKeys {
|
for _, changedUserID := range usersWithChangedKeys {
|
||||||
sharedUsersMap[userID] = 0
|
sharedUsersMap[changedUserID] = 0
|
||||||
|
if changedUserID == userID {
|
||||||
|
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
||||||
|
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
||||||
|
// be notified about key changes.
|
||||||
|
sharedUsersMap[userID] = 1
|
||||||
|
}
|
||||||
}
|
}
|
||||||
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
|
sharedUsers, err := db.SharedUsers(ctx, userID, usersWithChangedKeys)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
util.GetLogger(ctx).WithError(err).Errorf("db.SharedUsers failed: %s", err)
|
||||||
// default to all users so we do needless queries rather than miss some important device update
|
// default to all users so we do needless queries rather than miss some important device update
|
||||||
return nil, usersWithChangedKeys
|
return sharedUsersMap
|
||||||
}
|
}
|
||||||
for _, userID := range sharedUsers {
|
for _, userID := range sharedUsers {
|
||||||
sharedUsersMap[userID]++
|
sharedUsersMap[userID]++
|
||||||
}
|
}
|
||||||
// We forcibly put ourselves in this list because we should be notified about our own device updates
|
return sharedUsersMap
|
||||||
// and if we are in 0 rooms then we don't technically share any room with ourselves so we wouldn't
|
|
||||||
// be notified about key changes.
|
|
||||||
sharedUsersMap[userID] = 1
|
|
||||||
return sharedUsersMap, sharedUsers
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func joinedRooms(res *types.Response, userID string) []string {
|
func joinedRooms(res *types.Response, userID string) []string {
|
||||||
|
|
|
@ -129,6 +129,7 @@ type wantCatchup struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) {
|
func assertCatchup(t *testing.T, hasNew bool, syncResponse *types.Response, want wantCatchup) {
|
||||||
|
t.Helper()
|
||||||
if hasNew != want.hasNew {
|
if hasNew != want.hasNew {
|
||||||
t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew)
|
t.Errorf("got hasNew=%v want %v", hasNew, want.hasNew)
|
||||||
}
|
}
|
||||||
|
|
|
@ -112,7 +112,7 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
const selectSharedUsersSQL = "" +
|
const selectSharedUsersSQL = "" +
|
||||||
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
|
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
|
||||||
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
") AND state_key = ANY($2) AND membership='join';"
|
") AND state_key = ANY($2) AND membership IN ('join', 'invite');"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
upsertRoomStateStmt *sql.Stmt
|
upsertRoomStateStmt *sql.Stmt
|
||||||
|
@ -407,7 +407,7 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectSharedUsersStmt)
|
||||||
rows, err := stmt.QueryContext(ctx, userID, otherUserIDs)
|
rows, err := stmt.QueryContext(ctx, userID, pq.Array(otherUserIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -94,9 +94,9 @@ const selectEventsWithEventIDsSQL = "" +
|
||||||
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
||||||
|
|
||||||
const selectSharedUsersSQL = "" +
|
const selectSharedUsersSQL = "" +
|
||||||
"SELECT state_key FROM syncapi_current_room_state WHERE room_id = ANY(" +
|
"SELECT state_key FROM syncapi_current_room_state WHERE room_id IN(" +
|
||||||
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
" SELECT room_id FROM syncapi_current_room_state WHERE state_key = $1 AND membership='join'" +
|
||||||
") AND state_key IN ($2) AND membership='join';"
|
") AND state_key IN ($2) AND membership IN ('join', 'invite');"
|
||||||
|
|
||||||
type currentRoomStateStatements struct {
|
type currentRoomStateStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
@ -420,25 +420,28 @@ func (s *currentRoomStateStatements) SelectStateEvent(
|
||||||
func (s *currentRoomStateStatements) SelectSharedUsers(
|
func (s *currentRoomStateStatements) SelectSharedUsers(
|
||||||
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
ctx context.Context, txn *sql.Tx, userID string, otherUserIDs []string,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
|
|
||||||
stmt, err := s.db.Prepare(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("SelectSharedUsers s.db.Prepare: %w", err)
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, stmt, "SelectSharedUsers: stmt.close() failed")
|
|
||||||
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, userID, otherUserIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectSharedUsersStmt: rows.close() failed")
|
|
||||||
|
|
||||||
var stateKey string
|
params := make([]interface{}, len(otherUserIDs)+1)
|
||||||
result := make([]string, 0, len(otherUserIDs))
|
params[0] = userID
|
||||||
for rows.Next() {
|
for k, v := range otherUserIDs {
|
||||||
if err := rows.Scan(&stateKey); err != nil {
|
params[k+1] = v
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
result = append(result, stateKey)
|
|
||||||
}
|
}
|
||||||
return result, rows.Err()
|
|
||||||
|
result := make([]string, 0, len(otherUserIDs))
|
||||||
|
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
|
||||||
|
err := sqlutil.RunLimitedVariablesQuery(
|
||||||
|
ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
|
||||||
|
func(rows *sql.Rows) error {
|
||||||
|
var stateKey string
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(&stateKey); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result = append(result, stateKey)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return result, err
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue