From 62bcd5ad4b352a183f0c28a8e848b6ac7c67ee9e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 30 Jul 2021 16:27:55 +0100 Subject: [PATCH] Various refactoring --- clientapi/routing/keys.go | 15 ++-- federationapi/routing/devices.go | 12 +--- federationsender/consumers/keychange.go | 6 +- go.mod | 2 +- go.sum | 4 +- keyserver/api/api.go | 27 +------- keyserver/internal/device_list_update.go | 32 ++------- keyserver/internal/device_list_update_test.go | 62 +++++++++++------ keyserver/internal/internal.go | 69 +++++++++---------- .../storage/postgres/device_keys_table.go | 32 +++++---- .../storage/sqlite3/device_keys_table.go | 32 +++++---- keyserver/storage/storage_test.go | 43 +++++++----- userapi/internal/api.go | 24 ++++--- 13 files changed, 178 insertions(+), 182 deletions(-) diff --git a/clientapi/routing/keys.go b/clientapi/routing/keys.go index e2233642..ee1293b5 100644 --- a/clientapi/routing/keys.go +++ b/clientapi/routing/keys.go @@ -23,12 +23,13 @@ import ( "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/keyserver/api" userapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" ) type uploadKeysRequest struct { - DeviceKeys json.RawMessage `json:"device_keys"` - OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` + DeviceKeys gomatrixserverlib.DeviceKeys `json:"device_keys"` + OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"` } func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse { @@ -42,14 +43,8 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De DeviceID: device.ID, UserID: device.UserID, } - if r.DeviceKeys != nil { - uploadReq.DeviceKeys = []api.DeviceKeys{ - { - DeviceID: device.ID, - UserID: device.UserID, - KeyJSON: r.DeviceKeys, - }, - } + if r.DeviceKeys.DeviceID != "" { + uploadReq.DeviceKeys = append(uploadReq.DeviceKeys, r.DeviceKeys) } if r.OneTimeKeys != nil { uploadReq.OneTimeKeys = []api.OneTimeKeys{ diff --git a/federationapi/routing/devices.go b/federationapi/routing/devices.go index 07862451..e70a0486 100644 --- a/federationapi/routing/devices.go +++ b/federationapi/routing/devices.go @@ -13,7 +13,6 @@ package routing import ( - "encoding/json" "net/http" "github.com/matrix-org/dendrite/clientapi/jsonerror" @@ -44,17 +43,10 @@ func GetUserDevices( } for _, dev := range res.Devices { - var key gomatrixserverlib.RespUserDeviceKeys - err := json.Unmarshal(dev.DeviceKeys.KeyJSON, &key) - if err != nil { - util.GetLogger(req.Context()).WithError(err).Warnf("malformed device key: %s", string(dev.DeviceKeys.KeyJSON)) - continue - } - device := gomatrixserverlib.RespUserDevice{ DeviceID: dev.DeviceID, - DisplayName: dev.DisplayName, - Keys: key, + DisplayName: dev.DisplayName(), + Keys: dev.RespUserDeviceKeys, } response.Devices = append(response.Devices, device) } diff --git a/federationsender/consumers/keychange.go b/federationsender/consumers/keychange.go index 9e146390..db1b8f60 100644 --- a/federationsender/consumers/keychange.go +++ b/federationsender/consumers/keychange.go @@ -119,11 +119,11 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error { event := gomatrixserverlib.DeviceListUpdateEvent{ UserID: m.UserID, DeviceID: m.DeviceID, - DeviceDisplayName: m.DisplayName, + DeviceDisplayName: m.DisplayName(), StreamID: m.StreamID, PrevID: prevID(m.StreamID), - Deleted: len(m.KeyJSON) == 0, - Keys: m.KeyJSON, + Deleted: m.DeviceKeys == nil, + Keys: m.DeviceKeys, } if edu.Content, err = json.Marshal(event); err != nil { return err diff --git a/go.mod b/go.mod index fbad10d6..cdfa60f8 100644 --- a/go.mod +++ b/go.mod @@ -31,7 +31,7 @@ require ( github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 - github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876 + github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446 github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 diff --git a/go.sum b/go.sum index ff6090a5..5e35f7d4 100644 --- a/go.sum +++ b/go.sum @@ -1027,8 +1027,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1 github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4= github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= -github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876 h1:6ypwCtgRLK0v/hGWvnd847+KTo9BSkP9N0A4qSniP4E= -github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= +github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446 h1:ruNr86esoRKlzeJ9cRbOvARUvUowHBlX2sfzO93D+ws= +github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 h1:HZCzy4oVzz55e+cOMiX/JtSF2UOY1evBl2raaE7ACcU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b h1:5X5vdWQ13xrNkJVqaJHPsrt7rKkMJH5iac0EtfOuxSg= diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 5cb287bc..07df9963 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -49,32 +49,11 @@ func (k *KeyError) Error() string { // DeviceMessage represents the message produced into Kafka by the key server. type DeviceMessage struct { - DeviceKeys + *gomatrixserverlib.DeviceKeys // A monotonically increasing number which represents device changes for this user. StreamID int } -// DeviceKeys represents a set of device keys for a single device -// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload -type DeviceKeys struct { - // The user who owns this device - UserID string - // The device ID of this device - DeviceID string - // The device display name - DisplayName string - // The raw device key JSON - KeyJSON []byte -} - -// WithStreamID returns a copy of this device message with the given stream ID -func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage { - return DeviceMessage{ - DeviceKeys: *k, - StreamID: streamID, - } -} - // OneTimeKeys represents a set of one-time keys for a single device // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload type OneTimeKeys struct { @@ -110,7 +89,7 @@ type OneTimeKeysCount struct { type PerformUploadKeysRequest struct { UserID string // Required - User performing the request DeviceID string // Optional - Device performing the request, for fetching OTK count - DeviceKeys []DeviceKeys + DeviceKeys []gomatrixserverlib.DeviceKeys OneTimeKeys []OneTimeKeys // OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update // the display name for their respective device, and NOT to modify the keys. The key @@ -161,7 +140,7 @@ type QueryKeysResponse struct { // Map of remote server domain to error JSON Failures map[string]interface{} // Map of user_id to device_id to device_key - DeviceKeys map[string]map[string]json.RawMessage + DeviceKeys map[string]map[string]gomatrixserverlib.DeviceKeys // Set if there was a fatal error processing this query Error *KeyError } diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index 47bfb72c..de239271 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -16,7 +16,6 @@ package internal import ( "context" - "encoding/json" "fmt" "hash/fnv" "sync" @@ -215,23 +214,15 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. "prev_ids": event.PrevID, "display_name": event.DeviceDisplayName, "deleted": event.Deleted, + "keys": event.Keys, }).Info("DeviceListUpdater.Update") // if we haven't missed anything update the database and notify users if exists { - k := event.Keys - if event.Deleted { - k = nil - } keys := []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: k, - UserID: event.UserID, - }, - StreamID: event.StreamID, + DeviceKeys: event.Keys, + StreamID: event.StreamID, }, } err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil) @@ -388,24 +379,15 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi keys := make([]api.DeviceMessage, len(res.Devices)) existingKeys := make([]api.DeviceMessage, len(res.Devices)) for i, device := range res.Devices { - keyJSON, err := json.Marshal(device.Keys) - if err != nil { - util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device") - continue - } keys[i] = api.DeviceMessage{ StreamID: res.StreamID, - DeviceKeys: api.DeviceKeys{ - DeviceID: device.DeviceID, - DisplayName: device.DisplayName, - UserID: res.UserID, - KeyJSON: keyJSON, + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: device.Keys, }, } existingKeys[i] = api.DeviceMessage{ - DeviceKeys: api.DeviceKeys{ - UserID: res.UserID, - DeviceID: device.DeviceID, + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: device.Keys, }, } } diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index eab2a78d..344d091b 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -17,6 +17,7 @@ package internal import ( "context" "crypto/ed25519" + "encoding/json" "fmt" "io/ioutil" "net/http" @@ -128,10 +129,19 @@ func TestUpdateHavePrevID(t *testing.T) { DeviceDisplayName: "Foo Bar", Deleted: false, DeviceID: "FOO", - Keys: []byte(`{"key":"value"}`), - PrevID: []int{0}, - StreamID: 1, - UserID: "@alice:localhost", + Keys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "FOO", + UserID: "@alice:localhost", + Algorithms: []string{"TEST"}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + "key": {1, 2, 3, 4, 5, 6}, + }, + }, + }, + PrevID: []int{0}, + StreamID: 1, + UserID: "@alice:localhost", } err := updater.Update(ctx, event) if err != nil { @@ -139,11 +149,15 @@ func TestUpdateHavePrevID(t *testing.T) { } want := api.DeviceMessage{ StreamID: event.StreamID, - DeviceKeys: api.DeviceKeys{ - DeviceID: event.DeviceID, - DisplayName: event.DeviceDisplayName, - KeyJSON: event.Keys, - UserID: event.UserID, + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "FOO", + UserID: "@alice:localhost", + Algorithms: []string{"TEST"}, + Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{ + "key": {1, 2, 3, 4, 5, 6}, + }, + }, }, } if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { @@ -201,11 +215,16 @@ func TestUpdateNoPrevID(t *testing.T) { DeviceDisplayName: "Mobile Phone", Deleted: false, DeviceID: "another_device_id", - Keys: []byte(`{"key":"value"}`), + Keys: &gomatrixserverlib.DeviceKeys{}, PrevID: []int{3}, StreamID: 4, UserID: remoteUserID, } + if err := json.Unmarshal([]byte(keyJSON), event.Keys); err != nil { + t.Fatal(err) + } + event.DeviceID = "another_device" + event.Keys.DeviceID = "another_device" err := updater.Update(ctx, event) if err != nil { t.Fatalf("Update returned an error: %s", err) @@ -215,24 +234,25 @@ func TestUpdateNoPrevID(t *testing.T) { // wait a bit for db to be updated... time.Sleep(100 * time.Millisecond) want := api.DeviceMessage{ - StreamID: 5, - DeviceKeys: api.DeviceKeys{ - DeviceID: "JLAFKJWSCS", - DisplayName: "Mobile Phone", - UserID: remoteUserID, - KeyJSON: []byte(keyJSON), - }, + StreamID: 5, + DeviceKeys: &gomatrixserverlib.DeviceKeys{}, + } + if err := json.Unmarshal([]byte(keyJSON), want.DeviceKeys); err != nil { + t.Fatal(err) + } + want.Unsigned = map[string]interface{}{ + "device_display_name": "Mobile Phone", } // Now we should have a fresh list and the keys and emitted something if db.staleUsers[event.UserID] { t.Errorf("%s still marked as stale", event.UserID) } - if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { - t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON)) + if len(producer.events) != 1 { + t.Logf("len got %+v len want 1", len(producer.events)) t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + if !reflect.DeepEqual(*db.storedKeys[0].DeviceKeys, *want.DeviceKeys) { + t.Errorf("DB didn't store correct event\ngot %+v\nwant %+v", *db.storedKeys[0].DeviceKeys, *want.DeviceKeys) } } diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index f53a0761..cb9af4c4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -30,8 +30,6 @@ import ( "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" "github.com/sirupsen/logrus" - "github.com/tidwall/gjson" - "github.com/tidwall/sjson" ) type KeyInternalAPI struct { @@ -210,7 +208,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query // remove deleted devices var result []api.DeviceMessage for _, m := range msgs { - if m.KeyJSON == nil { + if m.DeviceKeys == nil { continue } result = append(result, m) @@ -220,7 +218,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query } func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) { - res.DeviceKeys = make(map[string]map[string]json.RawMessage) + res.DeviceKeys = make(map[string]map[string]gomatrixserverlib.DeviceKeys) res.Failures = make(map[string]interface{}) // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) @@ -254,21 +252,19 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques } if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) + res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys) } for _, dk := range deviceKeys { - if len(dk.KeyJSON) == 0 { + if dk.DeviceKeys == nil { continue // don't include blank keys } // inject display name if known (either locally or remotely) - displayName := dk.DisplayName + displayName := dk.DisplayName() if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" { displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName } - dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{displayName}) - res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON + dk.Unsigned["device_display_name"] = displayName + res.DeviceKeys[userID][dk.DeviceID] = *dk.DeviceKeys } } else { domainToDeviceKeys[domain] = make(map[string][]string) @@ -335,13 +331,9 @@ func (a *KeyInternalAPI) queryRemoteKeys( for result := range resultCh { for userID, nest := range result.DeviceKeys { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) + res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys) for deviceID, deviceKey := range nest { - keyJSON, err := json.Marshal(deviceKey) - if err != nil { - continue - } - res.DeviceKeys[userID][deviceID] = keyJSON + res.DeviceKeys[userID][deviceID] = deviceKey } } } @@ -432,18 +424,16 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase( return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID) } if res.DeviceKeys[userID] == nil { - res.DeviceKeys[userID] = make(map[string]json.RawMessage) + res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys) } for _, key := range keys { - if len(key.KeyJSON) == 0 { + if key.DeviceKeys == nil { continue // ignore deleted keys } // inject the display name - key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct { - DisplayName string `json:"device_display_name,omitempty"` - }{key.DisplayName}) - res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON + key.DeviceKeys.Unsigned["device_display_name"] = key.DisplayName() + res.DeviceKeys[userID][key.DeviceID] = *key.DeviceKeys } return nil } @@ -459,21 +449,19 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per if serverName != a.ThisServer { continue // ignore remote users } - if len(key.KeyJSON) == 0 { - keysToStore = append(keysToStore, key.WithStreamID(0)) + if len(key.Keys) == 0 { + keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0}) continue // deleted keys don't need sanity checking } - gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str - gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str - if gotUserID == key.UserID && gotDeviceID == key.DeviceID { - keysToStore = append(keysToStore, key.WithStreamID(0)) + if req.UserID == key.UserID && req.DeviceID == key.DeviceID { + keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0}) continue } res.KeyError(key.UserID, key.DeviceID, &api.KeyError{ Err: fmt.Sprintf( "user_id or device_id mismatch: users: %s - %s, devices: %s - %s", - gotUserID, key.UserID, gotDeviceID, key.DeviceID, + req.UserID, key.UserID, req.DeviceID, key.DeviceID, ), }) } @@ -482,9 +470,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per existingKeys := make([]api.DeviceMessage, len(keysToStore)) for i := range keysToStore { existingKeys[i] = api.DeviceMessage{ - DeviceKeys: api.DeviceKeys{ - UserID: keysToStore[i].UserID, - DeviceID: keysToStore[i].DeviceID, + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + UserID: keysToStore[i].UserID, + DeviceID: keysToStore[i].DeviceID, + }, }, } } @@ -574,9 +564,18 @@ func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.Device for _, newKey := range new { exists := false for _, existingKey := range existing { + newJSON, err := json.Marshal(newKey) + if err != nil { + return fmt.Errorf("json.Marshal(newKey): %w", err) + } + existingJSON, err := json.Marshal(existingKey) + if err != nil { + return fmt.Errorf("json.Marshal(existingKey): %w", err) + } + // Do not treat the absence of keys as equal, or else we will not emit key changes // when users delete devices which never had a key to begin with as both KeyJSONs are nil. - if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 { + if bytes.Equal(existingJSON, newJSON) && len(existingJSON) > 0 { exists = true break } @@ -594,7 +593,7 @@ func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage { if existingDevice.DeviceID != newDevice.DeviceID { continue } - existingDevice.DisplayName = newDevice.DisplayName + existingDevice.Unsigned["device_display_name"] = newDevice.DisplayName() existing[i] = existingDevice } } diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 95064fc8..243fd2f7 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var deviceKeysSchema = ` @@ -105,19 +106,23 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string + for _, key := range keys { var streamID int var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if key.DeviceKeys == nil { + key.DeviceKeys = &gomatrixserverlib.DeviceKeys{} + } + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID + key.StreamID = streamID if displayName.Valid { - keys[i].DisplayName = displayName.String + if key.DeviceKeys.Unsigned == nil { + key.DeviceKeys.Unsigned = make(map[string]interface{}) + } + key.DeviceKeys.Unsigned["device_display_name"] = displayName.String } } return nil @@ -153,7 +158,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName, ) if err != nil { return err @@ -179,18 +184,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID } var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceMessage + dk := api.DeviceMessage{ + DeviceKeys: &gomatrixserverlib.DeviceKeys{}, + } dk.UserID = userID - var keyJSON string var streamID int var displayName sql.NullString - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { + if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil { return nil, err } - dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID if displayName.Valid { - dk.DisplayName = displayName.String + if dk.DeviceKeys.Unsigned == nil { + dk.DeviceKeys.Unsigned = make(map[string]interface{}) + } + dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { diff --git a/keyserver/storage/sqlite3/device_keys_table.go b/keyserver/storage/sqlite3/device_keys_table.go index 9112fc6e..d569dec4 100644 --- a/keyserver/storage/sqlite3/device_keys_table.go +++ b/keyserver/storage/sqlite3/device_keys_table.go @@ -24,6 +24,7 @@ import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage/tables" + "github.com/matrix-org/gomatrixserverlib" ) var deviceKeysSchema = ` @@ -113,18 +114,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") var result []api.DeviceMessage for rows.Next() { - var dk api.DeviceMessage + dk := api.DeviceMessage{ + DeviceKeys: &gomatrixserverlib.DeviceKeys{}, + } dk.UserID = userID - var keyJSON string var streamID int var displayName sql.NullString - if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil { + if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil { return nil, err } - dk.KeyJSON = []byte(keyJSON) dk.StreamID = streamID if displayName.Valid { - dk.DisplayName = displayName.String + if dk.DeviceKeys.Unsigned == nil { + dk.DeviceKeys.Unsigned = make(map[string]interface{}) + } + dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String } // include the key if we want all keys (no device) or it was asked if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { @@ -135,19 +139,23 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID } func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error { - for i, key := range keys { - var keyJSONStr string + for _, key := range keys { var streamID int var displayName sql.NullString - err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName) + if key.DeviceKeys == nil { + key.DeviceKeys = &gomatrixserverlib.DeviceKeys{} + } + err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName) if err != nil && err != sql.ErrNoRows { return err } // this will be '' when there is no device - keys[i].KeyJSON = []byte(keyJSONStr) - keys[i].StreamID = streamID + key.StreamID = streamID if displayName.Valid { - keys[i].DisplayName = displayName.String + if key.DeviceKeys.Unsigned == nil { + key.DeviceKeys.Unsigned = make(map[string]interface{}) + } + key.DeviceKeys.Unsigned["device_display_name"] = displayName.String } } return nil @@ -189,7 +197,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName, + ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(), ) if err != nil { return err diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index afdb086d..75d1a1ba 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -12,6 +12,7 @@ import ( "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" ) var ctx = context.Background() @@ -105,28 +106,34 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { bob := "@bob:TestDeviceKeysStreamIDGeneration" msgs := []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "AAA", + UserID: alice, + Algorithms: []string{"v1"}, + }, }, // StreamID: 1 }, { - DeviceKeys: api.DeviceKeys{ - DeviceID: "AAA", - UserID: bob, - KeyJSON: []byte(`{"key":"v1"}`), + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "AAA", + UserID: bob, + Algorithms: []string{"v1"}, + }, }, // StreamID: 1 as this is a different user }, { - DeviceKeys: api.DeviceKeys{ - DeviceID: "another_device", - UserID: alice, - KeyJSON: []byte(`{"key":"v1"}`), + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "another_device", + UserID: alice, + Algorithms: []string{"v2"}, + }, }, - // StreamID: 2 as this is a 2nd device key + // StreamID: 2 as this is a 2nd key }, } MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs)) @@ -143,10 +150,12 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) { // updating a device sets the next stream ID for that user msgs = []api.DeviceMessage{ { - DeviceKeys: api.DeviceKeys{ - DeviceID: "AAA", - UserID: alice, - KeyJSON: []byte(`{"key":"v2"}`), + DeviceKeys: &gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: "AAA", + UserID: alice, + Algorithms: []string{"v3"}, + }, }, // StreamID: 3 }, diff --git a/userapi/internal/api.go b/userapi/internal/api.go index a2bc8ecf..fc6868d1 100644 --- a/userapi/internal/api.go +++ b/userapi/internal/api.go @@ -150,12 +150,13 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe } func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error { - deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs)) + deviceKeys := make([]gomatrixserverlib.DeviceKeys, len(deviceIDs)) for i, did := range deviceIDs { - deviceKeys[i] = keyapi.DeviceKeys{ - UserID: userID, - DeviceID: did, - KeyJSON: nil, + deviceKeys[i] = gomatrixserverlib.DeviceKeys{ + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + UserID: userID, + DeviceID: did, + }, } } @@ -219,12 +220,15 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf var uploadRes keyapi.PerformUploadKeysResponse a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{ UserID: req.RequestingUserID, - DeviceKeys: []keyapi.DeviceKeys{ + DeviceKeys: []gomatrixserverlib.DeviceKeys{ { - DeviceID: dev.ID, - DisplayName: *req.DisplayName, - KeyJSON: nil, - UserID: dev.UserID, + RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{ + DeviceID: dev.ID, + UserID: dev.UserID, + }, + Unsigned: map[string]interface{}{ + "device_display_name": *req.DisplayName, + }, }, }, OnlyDisplayNameUpdates: true,