mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
da7bca0224
This makes the following changes: - Adds two new metrics observing the usage of the `DeviceListUpdater` workers - Makes the number of workers configurable - Adds a 30s timeout for DB requests when receiving a device list update over federation
637 lines
22 KiB
Go
637 lines
22 KiB
Go
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package internal
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"net"
|
|
"strconv"
|
|
"sync"
|
|
"time"
|
|
|
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
|
"github.com/matrix-org/gomatrixserverlib/fclient"
|
|
"github.com/matrix-org/gomatrixserverlib/spec"
|
|
|
|
"github.com/matrix-org/gomatrix"
|
|
"github.com/matrix-org/gomatrixserverlib"
|
|
"github.com/matrix-org/util"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/sirupsen/logrus"
|
|
|
|
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
|
"github.com/matrix-org/dendrite/setup/process"
|
|
"github.com/matrix-org/dendrite/userapi/api"
|
|
)
|
|
|
|
var (
|
|
deviceListUpdateCount = prometheus.NewCounterVec(
|
|
prometheus.CounterOpts{
|
|
Namespace: "dendrite",
|
|
Subsystem: "keyserver",
|
|
Name: "device_list_update",
|
|
Help: "Number of times we have attempted to update device lists from this server",
|
|
},
|
|
[]string{"server"},
|
|
)
|
|
)
|
|
|
|
// requestTimeout bounds the context used while refreshing the device list of
// a single remote user.
const requestTimeout = 30 * time.Second
|
|
|
|
func init() {
|
|
prometheus.MustRegister(
|
|
deviceListUpdateCount,
|
|
)
|
|
}
|
|
|
|
// DeviceListUpdater handles device list updates from remote servers.
|
|
//
|
|
// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
|
|
// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
|
|
// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
|
|
// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
|
|
// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
|
|
// updater stores the latest list along with the latest stream ID.
|
|
//
|
|
// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
|
|
// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
|
|
// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
|
|
// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
|
|
// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
|
|
// for that domain.
|
|
// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
|
|
// we have many many servers)
|
|
// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
|
|
//
|
|
// The downsides are that:
|
|
// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
|
|
// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
|
|
// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
|
|
// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
|
|
// than being stuck behind foo.bar
|
|
//
|
|
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
|
|
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
|
|
type DeviceListUpdater struct {
|
|
process *process.ProcessContext
|
|
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
|
// request to the remote server and race.
|
|
// TODO: Put in an LRU cache to bound growth
|
|
userIDToMutex map[string]*sync.Mutex
|
|
mu *sync.Mutex // protects UserIDToMutex
|
|
|
|
db DeviceListUpdaterDatabase
|
|
api DeviceListUpdaterAPI
|
|
producer KeyChangeProducer
|
|
fedClient fedsenderapi.KeyserverFederationAPI
|
|
workerChans []chan spec.ServerName
|
|
thisServer spec.ServerName
|
|
|
|
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
|
// block on or timeout via a select.
|
|
userIDToChan map[string]chan bool
|
|
userIDToChanMu *sync.Mutex
|
|
rsAPI rsapi.KeyserverRoomserverAPI
|
|
}
|
|
|
|
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
|
// Useful for testing.
|
|
type DeviceListUpdaterDatabase interface {
|
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
|
// If no domains are given, all user IDs with stale device lists are returned.
|
|
StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error)
|
|
|
|
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
|
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
|
|
|
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
|
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
|
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
|
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
|
|
|
// PrevIDsExists returns true if all prev IDs exist for this user.
|
|
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
|
|
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
|
|
|
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
|
|
}
|
|
|
|
type DeviceListUpdaterAPI interface {
|
|
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse)
|
|
}
|
|
|
|
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
|
|
type KeyChangeProducer interface {
|
|
ProduceKeyChanges(keys []api.DeviceMessage) error
|
|
}
|
|
|
|
var deviceListUpdaterBackpressure = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Namespace: "dendrite",
|
|
Subsystem: "keyserver",
|
|
Name: "worker_backpressure",
|
|
Help: "How many device list updater requests are queued",
|
|
},
|
|
[]string{"worker_id"},
|
|
)
|
|
var deviceListUpdaterServersRetrying = prometheus.NewGaugeVec(
|
|
prometheus.GaugeOpts{
|
|
Namespace: "dendrite",
|
|
Subsystem: "keyserver",
|
|
Name: "worker_servers_retrying",
|
|
Help: "How many servers are queued for retry",
|
|
},
|
|
[]string{"worker_id"},
|
|
)
|
|
|
|
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
|
func NewDeviceListUpdater(
|
|
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
|
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
|
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
|
rsAPI rsapi.KeyserverRoomserverAPI, thisServer spec.ServerName,
|
|
enableMetrics bool,
|
|
) *DeviceListUpdater {
|
|
if enableMetrics {
|
|
prometheus.MustRegister(deviceListUpdaterBackpressure, deviceListUpdaterServersRetrying)
|
|
}
|
|
return &DeviceListUpdater{
|
|
process: process,
|
|
userIDToMutex: make(map[string]*sync.Mutex),
|
|
mu: &sync.Mutex{},
|
|
db: db,
|
|
api: api,
|
|
producer: producer,
|
|
fedClient: fedClient,
|
|
thisServer: thisServer,
|
|
workerChans: make([]chan spec.ServerName, numWorkers),
|
|
userIDToChan: make(map[string]chan bool),
|
|
userIDToChanMu: &sync.Mutex{},
|
|
rsAPI: rsAPI,
|
|
}
|
|
}
|
|
|
|
// Start the device list updater, which will try to refresh any stale device lists.
|
|
func (u *DeviceListUpdater) Start() error {
|
|
for i := 0; i < len(u.workerChans); i++ {
|
|
// Allocate a small buffer per channel.
|
|
// If the buffer limit is reached, backpressure will cause the processing of EDUs
|
|
// to stop (in this transaction) until key requests can be made.
|
|
ch := make(chan spec.ServerName, 10)
|
|
u.workerChans[i] = ch
|
|
go u.worker(ch, i)
|
|
}
|
|
|
|
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
newStaleLists := dedupeStaleLists(staleLists)
|
|
offset, step := time.Second*10, time.Second
|
|
if max := len(newStaleLists); max > 120 {
|
|
step = (time.Second * 120) / time.Duration(max)
|
|
}
|
|
for _, userID := range newStaleLists {
|
|
userID := userID // otherwise we are only sending the last entry
|
|
time.AfterFunc(offset, func() {
|
|
u.notifyWorkers(userID)
|
|
})
|
|
offset += step
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CleanUp removes stale device entries for users we don't share a room with anymore
|
|
func (u *DeviceListUpdater) CleanUp() error {
|
|
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []spec.ServerName{})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
res := rsapi.QueryLeftUsersResponse{}
|
|
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
|
|
return err
|
|
}
|
|
|
|
if len(res.LeftUsers) == 0 {
|
|
return nil
|
|
}
|
|
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
|
|
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
|
|
}
|
|
|
|
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
|
u.mu.Lock()
|
|
defer u.mu.Unlock()
|
|
if u.userIDToMutex[userID] == nil {
|
|
u.userIDToMutex[userID] = &sync.Mutex{}
|
|
}
|
|
return u.userIDToMutex[userID]
|
|
}
|
|
|
|
// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
|
|
// Blocks until the device list is synced or the timeout is reached.
|
|
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName spec.ServerName, userID string) error {
|
|
mu := u.mutex(userID)
|
|
mu.Lock()
|
|
err := u.db.MarkDeviceListStale(ctx, userID, true)
|
|
mu.Unlock()
|
|
if err != nil {
|
|
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
|
|
}
|
|
u.notifyWorkers(userID)
|
|
return nil
|
|
}
|
|
|
|
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
|
|
// which assumes when /send 200 OKs that the device lists have been updated.
|
|
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
|
isDeviceListStale, err := u.update(ctx, event)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if isDeviceListStale {
|
|
// poke workers to handle stale device lists
|
|
u.notifyWorkers(event.UserID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
|
|
mu := u.mutex(event.UserID)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
// check if we have the prev IDs
|
|
exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
|
}
|
|
// if this is the first time we're hearing about this user, sync the device list manually.
|
|
if len(event.PrevID) == 0 {
|
|
exists = false
|
|
}
|
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
|
"prev_ids_exist": exists,
|
|
"user_id": event.UserID,
|
|
"device_id": event.DeviceID,
|
|
"stream_id": event.StreamID,
|
|
"prev_ids": event.PrevID,
|
|
"display_name": event.DeviceDisplayName,
|
|
"deleted": event.Deleted,
|
|
}).Trace("DeviceListUpdater.Update")
|
|
|
|
// if we haven't missed anything update the database and notify users
|
|
if exists || event.Deleted {
|
|
k := event.Keys
|
|
if event.Deleted {
|
|
k = nil
|
|
}
|
|
keys := []api.DeviceMessage{
|
|
{
|
|
Type: api.TypeDeviceKeyUpdate,
|
|
DeviceKeys: &api.DeviceKeys{
|
|
DeviceID: event.DeviceID,
|
|
DisplayName: event.DeviceDisplayName,
|
|
KeyJSON: k,
|
|
UserID: event.UserID,
|
|
},
|
|
StreamID: event.StreamID,
|
|
},
|
|
}
|
|
|
|
// DeviceKeysJSON will side-effect modify this, so it needs
|
|
// to be a copy, not sharing any pointers with the above.
|
|
deviceKeysCopy := *keys[0].DeviceKeys
|
|
deviceKeysCopy.KeyJSON = nil
|
|
existingKeys := []api.DeviceMessage{
|
|
{
|
|
Type: keys[0].Type,
|
|
DeviceKeys: &deviceKeysCopy,
|
|
StreamID: keys[0].StreamID,
|
|
},
|
|
}
|
|
|
|
// fetch what keys we had already and only emit changes
|
|
if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
|
// non-fatal, log and continue
|
|
util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf(
|
|
"failed to query device keys json for calculating diffs",
|
|
)
|
|
}
|
|
|
|
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
|
}
|
|
|
|
if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
|
|
return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
|
|
}
|
|
return false, nil
|
|
}
|
|
|
|
err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
|
|
if err != nil {
|
|
return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (u *DeviceListUpdater) notifyWorkers(userID string) {
|
|
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
|
if err != nil {
|
|
return
|
|
}
|
|
hash := fnv.New32a()
|
|
_, _ = hash.Write([]byte(remoteServer))
|
|
index := int(int64(hash.Sum32()) % int64(len(u.workerChans)))
|
|
|
|
ch := u.assignChannel(userID)
|
|
deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(index)}).Inc()
|
|
defer deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(index)}).Dec()
|
|
u.workerChans[index] <- remoteServer
|
|
select {
|
|
case <-ch:
|
|
case <-time.After(10 * time.Second):
|
|
// we don't return an error in this case as it's not a failure condition.
|
|
// we mainly block for the benefit of sytest anyway
|
|
}
|
|
}
|
|
|
|
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
|
|
u.userIDToChanMu.Lock()
|
|
defer u.userIDToChanMu.Unlock()
|
|
if ch, ok := u.userIDToChan[userID]; ok {
|
|
return ch
|
|
}
|
|
ch := make(chan bool)
|
|
u.userIDToChan[userID] = ch
|
|
return ch
|
|
}
|
|
|
|
func (u *DeviceListUpdater) clearChannel(userID string) {
|
|
u.userIDToChanMu.Lock()
|
|
defer u.userIDToChanMu.Unlock()
|
|
if ch, ok := u.userIDToChan[userID]; ok {
|
|
close(ch)
|
|
delete(u.userIDToChan, userID)
|
|
}
|
|
}
|
|
|
|
func (u *DeviceListUpdater) worker(ch chan spec.ServerName, workerID int) {
|
|
retries := make(map[spec.ServerName]time.Time)
|
|
retriesMu := &sync.Mutex{}
|
|
// restarter goroutine which will inject failed servers into ch when it is time
|
|
go func() {
|
|
var serversToRetry []spec.ServerName
|
|
for {
|
|
serversToRetry = serversToRetry[:0] // reuse memory
|
|
time.Sleep(time.Second)
|
|
retriesMu.Lock()
|
|
now := time.Now()
|
|
for srv, retryAt := range retries {
|
|
if now.After(retryAt) {
|
|
serversToRetry = append(serversToRetry, srv)
|
|
}
|
|
}
|
|
for _, srv := range serversToRetry {
|
|
delete(retries, srv)
|
|
}
|
|
deviceListUpdaterServersRetrying.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Set(float64(len(retries)))
|
|
retriesMu.Unlock()
|
|
for _, srv := range serversToRetry {
|
|
deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Inc()
|
|
ch <- srv
|
|
deviceListUpdaterBackpressure.With(prometheus.Labels{"worker_id": strconv.Itoa(workerID)}).Dec()
|
|
}
|
|
}
|
|
}()
|
|
for serverName := range ch {
|
|
retriesMu.Lock()
|
|
_, exists := retries[serverName]
|
|
retriesMu.Unlock()
|
|
if exists {
|
|
// Don't retry a server that we're already waiting for.
|
|
continue
|
|
}
|
|
waitTime, shouldRetry := u.processServer(serverName)
|
|
if shouldRetry {
|
|
retriesMu.Lock()
|
|
if _, exists = retries[serverName]; !exists {
|
|
retries[serverName] = time.Now().Add(waitTime)
|
|
}
|
|
retriesMu.Unlock()
|
|
}
|
|
}
|
|
}
|
|
|
|
func (u *DeviceListUpdater) processServer(serverName spec.ServerName) (time.Duration, bool) {
|
|
ctx := u.process.Context()
|
|
// If the process.Context is canceled, there is no need to go further.
|
|
// This avoids spamming the logs when shutting down
|
|
if errors.Is(ctx.Err(), context.Canceled) {
|
|
return defaultWaitTime, false
|
|
}
|
|
|
|
logger := util.GetLogger(ctx).WithField("server_name", serverName)
|
|
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
|
|
|
|
waitTime := defaultWaitTime // How long should we wait to try again?
|
|
successCount := 0 // How many user requests failed?
|
|
|
|
userIDs, err := u.db.StaleDeviceLists(ctx, []spec.ServerName{serverName})
|
|
if err != nil {
|
|
logger.WithError(err).Error("Failed to load stale device lists")
|
|
return waitTime, true
|
|
}
|
|
|
|
for _, userID := range userIDs {
|
|
userWait, err := u.processServerUser(ctx, serverName, userID)
|
|
if err != nil {
|
|
if userWait > waitTime {
|
|
waitTime = userWait
|
|
}
|
|
break
|
|
}
|
|
successCount++
|
|
}
|
|
|
|
allUsersSucceeded := successCount == len(userIDs)
|
|
if !allUsersSucceeded {
|
|
logger.WithFields(logrus.Fields{
|
|
"total": len(userIDs),
|
|
"succeeded": successCount,
|
|
"failed": len(userIDs) - successCount,
|
|
"wait_time": waitTime,
|
|
}).Debug("Failed to query device keys for some users")
|
|
}
|
|
return waitTime, !allUsersSucceeded
|
|
}
|
|
|
|
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName spec.ServerName, userID string) (time.Duration, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
|
defer cancel()
|
|
|
|
// If we are processing more than one user per server, this unblocks further calls to Update
|
|
// immediately instead of just after **all** users have been processed.
|
|
defer u.clearChannel(userID)
|
|
|
|
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
|
|
"server_name": serverName,
|
|
"user_id": userID,
|
|
})
|
|
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
|
|
if err != nil {
|
|
if errors.Is(err, context.DeadlineExceeded) {
|
|
return time.Minute * 10, err
|
|
}
|
|
switch e := err.(type) {
|
|
case *json.UnmarshalTypeError, *json.SyntaxError:
|
|
logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
|
|
return defaultWaitTime, nil
|
|
case *fedsenderapi.FederationClientError:
|
|
if e.RetryAfter > 0 {
|
|
return e.RetryAfter, err
|
|
} else if e.Blacklisted {
|
|
return time.Hour * 8, err
|
|
}
|
|
case net.Error:
|
|
// Use the default waitTime, if it's a timeout.
|
|
// It probably doesn't make sense to try further users.
|
|
if !e.Timeout() {
|
|
logger.WithError(e).Debug("GetUserDevices returned net.Error")
|
|
return time.Minute * 10, err
|
|
}
|
|
case gomatrix.HTTPError:
|
|
// The remote server returned an error, give it some time to recover.
|
|
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
|
|
if e.Code >= 300 {
|
|
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
|
|
return hourWaitTime, err
|
|
}
|
|
default:
|
|
// Something else failed
|
|
logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
|
|
return time.Minute * 10, err
|
|
}
|
|
}
|
|
if res.UserID != userID {
|
|
logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
|
|
return defaultWaitTime, nil
|
|
}
|
|
if res.MasterKey != nil || res.SelfSigningKey != nil {
|
|
uploadReq := &api.PerformUploadDeviceKeysRequest{
|
|
UserID: userID,
|
|
}
|
|
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
|
if res.MasterKey != nil {
|
|
if err = sanityCheckKey(*res.MasterKey, userID, fclient.CrossSigningKeyPurposeMaster); err == nil {
|
|
uploadReq.MasterKey = *res.MasterKey
|
|
}
|
|
}
|
|
if res.SelfSigningKey != nil {
|
|
if err = sanityCheckKey(*res.SelfSigningKey, userID, fclient.CrossSigningKeyPurposeSelfSigning); err == nil {
|
|
uploadReq.SelfSigningKey = *res.SelfSigningKey
|
|
}
|
|
}
|
|
u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
|
|
}
|
|
err = u.updateDeviceList(&res)
|
|
if err != nil {
|
|
logger.WithError(err).Error("Fetched device list but failed to store/emit it")
|
|
return defaultWaitTime, err
|
|
}
|
|
return defaultWaitTime, nil
|
|
}
|
|
|
|
func (u *DeviceListUpdater) updateDeviceList(res *fclient.RespUserDevices) error {
|
|
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
|
|
keys := make([]api.DeviceMessage, len(res.Devices))
|
|
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
|
for i, device := range res.Devices {
|
|
keyJSON, err := json.Marshal(device.Keys)
|
|
if err != nil {
|
|
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
|
continue
|
|
}
|
|
keys[i] = api.DeviceMessage{
|
|
Type: api.TypeDeviceKeyUpdate,
|
|
StreamID: res.StreamID,
|
|
DeviceKeys: &api.DeviceKeys{
|
|
DeviceID: device.DeviceID,
|
|
DisplayName: device.DisplayName,
|
|
UserID: res.UserID,
|
|
KeyJSON: keyJSON,
|
|
},
|
|
}
|
|
existingKeys[i] = api.DeviceMessage{
|
|
Type: api.TypeDeviceKeyUpdate,
|
|
DeviceKeys: &api.DeviceKeys{
|
|
UserID: res.UserID,
|
|
DeviceID: device.DeviceID,
|
|
},
|
|
}
|
|
}
|
|
// fetch what keys we had already and only emit changes
|
|
if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
|
// non-fatal, log and continue
|
|
util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf(
|
|
"failed to query device keys json for calculating diffs",
|
|
)
|
|
}
|
|
|
|
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
|
if err != nil {
|
|
return fmt.Errorf("failed to store remote device keys: %w", err)
|
|
}
|
|
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to mark device list as fresh: %w", err)
|
|
}
|
|
err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// dedupeStaleLists de-duplicates the stateList entries using the domain.
|
|
// This is used on startup, processServer is getting all users anyway, so
|
|
// there is no need to send every user to the workers.
|
|
func dedupeStaleLists(staleLists []string) []string {
|
|
seenDomains := make(map[spec.ServerName]struct{})
|
|
newStaleLists := make([]string, 0, len(staleLists))
|
|
for _, userID := range staleLists {
|
|
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
|
if err != nil {
|
|
// non-fatal and should not block starting up
|
|
continue
|
|
}
|
|
if _, ok := seenDomains[domain]; ok {
|
|
continue
|
|
}
|
|
newStaleLists = append(newStaleLists, userID)
|
|
seenDomains[domain] = struct{}{}
|
|
}
|
|
return newStaleLists
|
|
}
|