mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 14:12:47 +00:00
Various refactoring
This commit is contained in:
parent
ed4097825b
commit
62bcd5ad4b
13 changed files with 178 additions and 182 deletions
|
@ -16,7 +16,6 @@ package internal
|
|||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"sync"
|
||||
|
@ -215,23 +214,15 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
|||
"prev_ids": event.PrevID,
|
||||
"display_name": event.DeviceDisplayName,
|
||||
"deleted": event.Deleted,
|
||||
"keys": event.Keys,
|
||||
}).Info("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists {
|
||||
k := event.Keys
|
||||
if event.Deleted {
|
||||
k = nil
|
||||
}
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: k,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
StreamID: event.StreamID,
|
||||
DeviceKeys: event.Keys,
|
||||
StreamID: event.StreamID,
|
||||
},
|
||||
}
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||
|
@ -388,24 +379,15 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
|
|||
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
||||
for i, device := range res.Devices {
|
||||
keyJSON, err := json.Marshal(device.Keys)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
||||
continue
|
||||
}
|
||||
keys[i] = api.DeviceMessage{
|
||||
StreamID: res.StreamID,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: device.DeviceID,
|
||||
DisplayName: device.DisplayName,
|
||||
UserID: res.UserID,
|
||||
KeyJSON: keyJSON,
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: device.Keys,
|
||||
},
|
||||
}
|
||||
existingKeys[i] = api.DeviceMessage{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
UserID: res.UserID,
|
||||
DeviceID: device.DeviceID,
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: device.Keys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package internal
|
|||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
@ -128,10 +129,19 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
DeviceDisplayName: "Foo Bar",
|
||||
Deleted: false,
|
||||
DeviceID: "FOO",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
Keys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "FOO",
|
||||
UserID: "@alice:localhost",
|
||||
Algorithms: []string{"TEST"},
|
||||
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
|
||||
"key": {1, 2, 3, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
},
|
||||
PrevID: []int{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
|
@ -139,11 +149,15 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
}
|
||||
want := api.DeviceMessage{
|
||||
StreamID: event.StreamID,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
UserID: event.UserID,
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
DeviceID: "FOO",
|
||||
UserID: "@alice:localhost",
|
||||
Algorithms: []string{"TEST"},
|
||||
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
|
||||
"key": {1, 2, 3, 4, 5, 6},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
|
@ -201,11 +215,16 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
DeviceDisplayName: "Mobile Phone",
|
||||
Deleted: false,
|
||||
DeviceID: "another_device_id",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
Keys: &gomatrixserverlib.DeviceKeys{},
|
||||
PrevID: []int{3},
|
||||
StreamID: 4,
|
||||
UserID: remoteUserID,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(keyJSON), event.Keys); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
event.DeviceID = "another_device"
|
||||
event.Keys.DeviceID = "another_device"
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
|
@ -215,24 +234,25 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
// wait a bit for db to be updated...
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
want := api.DeviceMessage{
|
||||
StreamID: 5,
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
DeviceID: "JLAFKJWSCS",
|
||||
DisplayName: "Mobile Phone",
|
||||
UserID: remoteUserID,
|
||||
KeyJSON: []byte(keyJSON),
|
||||
},
|
||||
StreamID: 5,
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||
}
|
||||
if err := json.Unmarshal([]byte(keyJSON), want.DeviceKeys); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
want.Unsigned = map[string]interface{}{
|
||||
"device_display_name": "Mobile Phone",
|
||||
}
|
||||
// Now we should have a fresh list and the keys and emitted something
|
||||
if db.staleUsers[event.UserID] {
|
||||
t.Errorf("%s still marked as stale", event.UserID)
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
||||
if len(producer.events) != 1 {
|
||||
t.Logf("len got %+v len want 1", len(producer.events))
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
if !reflect.DeepEqual(*db.storedKeys[0].DeviceKeys, *want.DeviceKeys) {
|
||||
t.Errorf("DB didn't store correct event\ngot %+v\nwant %+v", *db.storedKeys[0].DeviceKeys, *want.DeviceKeys)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -30,8 +30,6 @@ import (
|
|||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type KeyInternalAPI struct {
|
||||
|
@ -210,7 +208,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
|||
// remove deleted devices
|
||||
var result []api.DeviceMessage
|
||||
for _, m := range msgs {
|
||||
if m.KeyJSON == nil {
|
||||
if m.DeviceKeys == nil {
|
||||
continue
|
||||
}
|
||||
result = append(result, m)
|
||||
|
@ -220,7 +218,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||
res.DeviceKeys = make(map[string]map[string]gomatrixserverlib.DeviceKeys)
|
||||
res.Failures = make(map[string]interface{})
|
||||
// make a map from domain to device keys
|
||||
domainToDeviceKeys := make(map[string]map[string][]string)
|
||||
|
@ -254,21 +252,19 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
|||
}
|
||||
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||
}
|
||||
for _, dk := range deviceKeys {
|
||||
if len(dk.KeyJSON) == 0 {
|
||||
if dk.DeviceKeys == nil {
|
||||
continue // don't include blank keys
|
||||
}
|
||||
// inject display name if known (either locally or remotely)
|
||||
displayName := dk.DisplayName
|
||||
displayName := dk.DisplayName()
|
||||
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||
}
|
||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{displayName})
|
||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||
dk.Unsigned["device_display_name"] = displayName
|
||||
res.DeviceKeys[userID][dk.DeviceID] = *dk.DeviceKeys
|
||||
}
|
||||
} else {
|
||||
domainToDeviceKeys[domain] = make(map[string][]string)
|
||||
|
@ -335,13 +331,9 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
|||
|
||||
for result := range resultCh {
|
||||
for userID, nest := range result.DeviceKeys {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||
for deviceID, deviceKey := range nest {
|
||||
keyJSON, err := json.Marshal(deviceKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res.DeviceKeys[userID][deviceID] = keyJSON
|
||||
res.DeviceKeys[userID][deviceID] = deviceKey
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -432,18 +424,16 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
|||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
if key.DeviceKeys == nil {
|
||||
continue // ignore deleted keys
|
||||
}
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
key.DeviceKeys.Unsigned["device_display_name"] = key.DisplayName()
|
||||
res.DeviceKeys[userID][key.DeviceID] = *key.DeviceKeys
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -459,21 +449,19 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
if serverName != a.ThisServer {
|
||||
continue // ignore remote users
|
||||
}
|
||||
if len(key.KeyJSON) == 0 {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
if len(key.Keys) == 0 {
|
||||
keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0})
|
||||
continue // deleted keys don't need sanity checking
|
||||
}
|
||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
if req.UserID == key.UserID && req.DeviceID == key.DeviceID {
|
||||
keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0})
|
||||
continue
|
||||
}
|
||||
|
||||
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf(
|
||||
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
||||
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
||||
req.UserID, key.UserID, req.DeviceID, key.DeviceID,
|
||||
),
|
||||
})
|
||||
}
|
||||
|
@ -482,9 +470,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
|||
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
||||
for i := range keysToStore {
|
||||
existingKeys[i] = api.DeviceMessage{
|
||||
DeviceKeys: api.DeviceKeys{
|
||||
UserID: keysToStore[i].UserID,
|
||||
DeviceID: keysToStore[i].DeviceID,
|
||||
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||
UserID: keysToStore[i].UserID,
|
||||
DeviceID: keysToStore[i].DeviceID,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
@ -574,9 +564,18 @@ func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.Device
|
|||
for _, newKey := range new {
|
||||
exists := false
|
||||
for _, existingKey := range existing {
|
||||
newJSON, err := json.Marshal(newKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.Marshal(newKey): %w", err)
|
||||
}
|
||||
existingJSON, err := json.Marshal(existingKey)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.Marshal(existingKey): %w", err)
|
||||
}
|
||||
|
||||
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
||||
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
||||
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
|
||||
if bytes.Equal(existingJSON, newJSON) && len(existingJSON) > 0 {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
|
@ -594,7 +593,7 @@ func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
|
|||
if existingDevice.DeviceID != newDevice.DeviceID {
|
||||
continue
|
||||
}
|
||||
existingDevice.DisplayName = newDevice.DisplayName
|
||||
existingDevice.Unsigned["device_display_name"] = newDevice.DisplayName()
|
||||
existing[i] = existingDevice
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue