mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Finish inbound E2E device lists (#1243)
* Add tests for device list updates * Add stale_device_lists table and use db before asking remote for device keys * Fetch remote keys if all devices are requested * Add display_name col to store remote device names Few other tweaks to make `Server correctly handles incoming m.device_list_update` pass. * Fix sqlite otk bug * Unbuffered channel to block /send causing sytest to not race anymore * Linting and fix bug whereby we didn't send updated dl tokens to the client causing a tightloop on /sync sometimes * No longer assert staleness as Update blocks on workers now * Back out tweaks * Bugfixes
This commit is contained in:
parent
30c2325eaf
commit
f371783da7
15 changed files with 639 additions and 48 deletions
|
@ -23,7 +23,6 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/producers"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -65,7 +64,7 @@ type DeviceListUpdater struct {
|
|||
mu *sync.Mutex // protects UserIDToMutex
|
||||
|
||||
db DeviceListUpdaterDatabase
|
||||
producer *producers.KeyChange
|
||||
producer KeyChangeProducer
|
||||
fedClient *gomatrixserverlib.FederationClient
|
||||
workerChans []chan gomatrixserverlib.ServerName
|
||||
}
|
||||
|
@ -88,9 +87,14 @@ type DeviceListUpdaterDatabase interface {
|
|||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
}
|
||||
|
||||
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
|
||||
type KeyChangeProducer interface {
|
||||
ProduceKeyChanges(keys []api.DeviceMessage) error
|
||||
}
|
||||
|
||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||
func NewDeviceListUpdater(
|
||||
db DeviceListUpdaterDatabase, producer *producers.KeyChange, fedClient *gomatrixserverlib.FederationClient,
|
||||
db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient *gomatrixserverlib.FederationClient,
|
||||
numWorkers int,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
|
@ -154,12 +158,17 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
|||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
// if this is the first time we're hearing about this user, sync the device list manually.
|
||||
if len(event.PrevID) == 0 {
|
||||
exists = false
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"prev_ids_exist": exists,
|
||||
"user_id": event.UserID,
|
||||
"device_id": event.DeviceID,
|
||||
"stream_id": event.StreamID,
|
||||
"prev_ids": event.PrevID,
|
||||
"display_name": event.DeviceDisplayName,
|
||||
}).Info("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
|
@ -263,16 +272,17 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
|
|||
hasFailures = true
|
||||
continue
|
||||
}
|
||||
err = u.updateDeviceList(ctx, &res)
|
||||
err = u.updateDeviceList(&res)
|
||||
if err != nil {
|
||||
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store it")
|
||||
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
|
||||
hasFailures = true
|
||||
}
|
||||
}
|
||||
return hasFailures
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixserverlib.RespUserDevices) error {
|
||||
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
|
||||
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
|
||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||
for i, device := range res.Devices {
|
||||
keyJSON, err := json.Marshal(device.Keys)
|
||||
|
@ -292,7 +302,15 @@ func (u *DeviceListUpdater) updateDeviceList(ctx context.Context, res *gomatrixs
|
|||
}
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||
}
|
||||
return u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark device list as fresh: %w", err)
|
||||
}
|
||||
err = u.producer.ProduceKeyChanges(keys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
242
keyserver/internal/device_list_update_test.go
Normal file
242
keyserver/internal/device_list_update_test.go
Normal file
|
@ -0,0 +1,242 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
type mockKeyChangeProducer struct {
|
||||
events []api.DeviceMessage
|
||||
}
|
||||
|
||||
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
p.events = append(p.events, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeviceListUpdaterDatabase struct {
|
||||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
var result []string
|
||||
for userID := range d.staleUsers {
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
result = append(result, userID)
|
||||
continue
|
||||
}
|
||||
for _, d := range domains {
|
||||
if remoteServer == d {
|
||||
result = append(result, userID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
d.staleUsers[userID] = isStale
|
||||
return nil
|
||||
}
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
d.storedKeys = append(d.storedKeys, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||
return d.prevIDsExist(userID, prevIDs), nil
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
fn func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.fn(req)
|
||||
}
|
||||
|
||||
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
|
||||
_, pkey, _ := ed25519.GenerateKey(nil)
|
||||
fedClient := gomatrixserverlib.NewFederationClient(
|
||||
gomatrixserverlib.ServerName("example.test"), gomatrixserverlib.KeyID("ed25519:test"), pkey,
|
||||
)
|
||||
fedClient.Client = *gomatrixserverlib.NewClientWithTransport(&roundTripper{tripper})
|
||||
return fedClient
|
||||
}
|
||||
|
||||
// Test that the device keys get persisted and emitted if we have the previous IDs.
|
||||
func TestUpdateHavePrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
updater := NewDeviceListUpdater(db, producer, nil, 1)
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Foo Bar",
|
||||
Deleted: false,
|
||||
DeviceID: "FOO",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
want := api.DeviceMessage{
|
||||
StreamID: event.StreamID,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
if db.staleUsers[event.UserID] {
|
||||
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that device keys are fetched from the remote server if we are missing prev IDs
|
||||
// and that the user's devices are marked as stale until it succeeds.
|
||||
func TestUpdateNoPrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
remoteUserID := "@alice:example.somewhere"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
defer wg.Done()
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + remoteUserID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(db, producer, fedClient, 2)
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Mobile Phone",
|
||||
Deleted: false,
|
||||
DeviceID: "another_device_id",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int{3},
|
||||
StreamID: 4,
|
||||
UserID: remoteUserID,
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
// At this point we show have this device list marked as stale and not store the keys or emitted anything
|
||||
if !db.staleUsers[event.UserID] {
|
||||
t.Errorf("%s not marked as stale", event.UserID)
|
||||
}
|
||||
if len(producer.events) > 0 {
|
||||
t.Errorf("Update incorrect emitted %d device change events", len(producer.events))
|
||||
}
|
||||
if len(db.storedKeys) > 0 {
|
||||
t.Errorf("Update incorrect stored %d device change events", len(db.storedKeys))
|
||||
}
|
||||
t.Log("waiting for /users/devices to be called...")
|
||||
wg.Wait()
|
||||
// wait a bit for db to be updated...
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
want := api.DeviceMessage{
|
||||
StreamID: 5,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "JLAFKJWSCS",
|
||||
DisplayName: "Mobile Phone",
|
||||
UserID: remoteUserID,
|
||||
KeyJSON: []byte(keyJSON),
|
||||
},
|
||||
}
|
||||
// Now we should have a fresh list and the keys and emitted something
|
||||
if db.staleUsers[event.UserID] {
|
||||
t.Errorf("%s still marked as stale", event.UserID)
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
|
||||
}
|
|
@ -250,10 +250,14 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
if len(dk.KeyJSON) == 0 {
|
||||
continue // don't include blank keys
|
||||
}
|
||||
// inject display name if known
|
||||
// inject display name if known (either locally or remotely)
|
||||
displayName := dk.DisplayName
|
||||
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||
}
|
||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{queryRes.DeviceInfo[dk.DeviceID].DisplayName})
|
||||
}{displayName})
|
||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||
}
|
||||
} else {
|
||||
|
@ -261,12 +265,49 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
||||
}
|
||||
}
|
||||
// TODO: set device display names when they are known
|
||||
|
||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
|
||||
if len(domainToDeviceKeys) == 0 {
|
||||
return // nothing to query
|
||||
}
|
||||
|
||||
// perform key queries for remote devices
|
||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) remoteKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||
) map[string]map[string][]string {
|
||||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
for userID, deviceIDs := range userToDeviceMap {
|
||||
keys, err := a.DB.DeviceKeysForUser(ctx, userID, deviceIDs)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
// Likewise, we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) == 0 || err != nil || len(keys) < len(deviceIDs) {
|
||||
if _, ok := fetchRemote[domain]; !ok {
|
||||
fetchRemote[domain] = make(map[string][]string)
|
||||
}
|
||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||
continue
|
||||
}
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for _, key := range keys {
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
}
|
||||
}
|
||||
}
|
||||
return fetchRemote
|
||||
}
|
||||
|
||||
func (a *KeyInternalAPI) queryRemoteKeys(
|
||||
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
|
||||
) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue