mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Various refactoring
This commit is contained in:
parent
ed4097825b
commit
62bcd5ad4b
13 changed files with 178 additions and 182 deletions
|
@ -23,12 +23,13 @@ import (
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
type uploadKeysRequest struct {
|
type uploadKeysRequest struct {
|
||||||
DeviceKeys json.RawMessage `json:"device_keys"`
|
DeviceKeys gomatrixserverlib.DeviceKeys `json:"device_keys"`
|
||||||
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
OneTimeKeys map[string]json.RawMessage `json:"one_time_keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse {
|
func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.Device) util.JSONResponse {
|
||||||
|
@ -42,14 +43,8 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De
|
||||||
DeviceID: device.ID,
|
DeviceID: device.ID,
|
||||||
UserID: device.UserID,
|
UserID: device.UserID,
|
||||||
}
|
}
|
||||||
if r.DeviceKeys != nil {
|
if r.DeviceKeys.DeviceID != "" {
|
||||||
uploadReq.DeviceKeys = []api.DeviceKeys{
|
uploadReq.DeviceKeys = append(uploadReq.DeviceKeys, r.DeviceKeys)
|
||||||
{
|
|
||||||
DeviceID: device.ID,
|
|
||||||
UserID: device.UserID,
|
|
||||||
KeyJSON: r.DeviceKeys,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if r.OneTimeKeys != nil {
|
if r.OneTimeKeys != nil {
|
||||||
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
uploadReq.OneTimeKeys = []api.OneTimeKeys{
|
||||||
|
|
|
@ -13,7 +13,6 @@
|
||||||
package routing
|
package routing
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||||
|
@ -44,17 +43,10 @@ func GetUserDevices(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, dev := range res.Devices {
|
for _, dev := range res.Devices {
|
||||||
var key gomatrixserverlib.RespUserDeviceKeys
|
|
||||||
err := json.Unmarshal(dev.DeviceKeys.KeyJSON, &key)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Warnf("malformed device key: %s", string(dev.DeviceKeys.KeyJSON))
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
device := gomatrixserverlib.RespUserDevice{
|
device := gomatrixserverlib.RespUserDevice{
|
||||||
DeviceID: dev.DeviceID,
|
DeviceID: dev.DeviceID,
|
||||||
DisplayName: dev.DisplayName,
|
DisplayName: dev.DisplayName(),
|
||||||
Keys: key,
|
Keys: dev.RespUserDeviceKeys,
|
||||||
}
|
}
|
||||||
response.Devices = append(response.Devices, device)
|
response.Devices = append(response.Devices, device)
|
||||||
}
|
}
|
||||||
|
|
|
@ -119,11 +119,11 @@ func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||||
UserID: m.UserID,
|
UserID: m.UserID,
|
||||||
DeviceID: m.DeviceID,
|
DeviceID: m.DeviceID,
|
||||||
DeviceDisplayName: m.DisplayName,
|
DeviceDisplayName: m.DisplayName(),
|
||||||
StreamID: m.StreamID,
|
StreamID: m.StreamID,
|
||||||
PrevID: prevID(m.StreamID),
|
PrevID: prevID(m.StreamID),
|
||||||
Deleted: len(m.KeyJSON) == 0,
|
Deleted: m.DeviceKeys == nil,
|
||||||
Keys: m.KeyJSON,
|
Keys: m.DeviceKeys,
|
||||||
}
|
}
|
||||||
if edu.Content, err = json.Marshal(event); err != nil {
|
if edu.Content, err = json.Marshal(event); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
2
go.mod
2
go.mod
|
@ -31,7 +31,7 @@ require (
|
||||||
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
|
||||||
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446
|
||||||
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0
|
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0
|
||||||
github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b
|
github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -1027,8 +1027,8 @@ github.com/matrix-org/go-sqlite3-js v0.0.0-20210709140738-b0d1ba599a6d/go.mod h1
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 h1:ZtO5uywdd5dLDCud4r0r55eP4j9FuUNpl60Gmntcop4=
|
||||||
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876 h1:6ypwCtgRLK0v/hGWvnd847+KTo9BSkP9N0A4qSniP4E=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446 h1:ruNr86esoRKlzeJ9cRbOvARUvUowHBlX2sfzO93D+ws=
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210722110442-5061d6986876/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210730143905-056ddf6fc446/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 h1:HZCzy4oVzz55e+cOMiX/JtSF2UOY1evBl2raaE7ACcU=
|
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 h1:HZCzy4oVzz55e+cOMiX/JtSF2UOY1evBl2raaE7ACcU=
|
||||||
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
|
||||||
github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b h1:5X5vdWQ13xrNkJVqaJHPsrt7rKkMJH5iac0EtfOuxSg=
|
github.com/matrix-org/pinecone v0.0.0-20210623102758-74f885644c1b h1:5X5vdWQ13xrNkJVqaJHPsrt7rKkMJH5iac0EtfOuxSg=
|
||||||
|
|
|
@ -49,32 +49,11 @@ func (k *KeyError) Error() string {
|
||||||
|
|
||||||
// DeviceMessage represents the message produced into Kafka by the key server.
|
// DeviceMessage represents the message produced into Kafka by the key server.
|
||||||
type DeviceMessage struct {
|
type DeviceMessage struct {
|
||||||
DeviceKeys
|
*gomatrixserverlib.DeviceKeys
|
||||||
// A monotonically increasing number which represents device changes for this user.
|
// A monotonically increasing number which represents device changes for this user.
|
||||||
StreamID int
|
StreamID int
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceKeys represents a set of device keys for a single device
|
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
|
||||||
type DeviceKeys struct {
|
|
||||||
// The user who owns this device
|
|
||||||
UserID string
|
|
||||||
// The device ID of this device
|
|
||||||
DeviceID string
|
|
||||||
// The device display name
|
|
||||||
DisplayName string
|
|
||||||
// The raw device key JSON
|
|
||||||
KeyJSON []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithStreamID returns a copy of this device message with the given stream ID
|
|
||||||
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
|
|
||||||
return DeviceMessage{
|
|
||||||
DeviceKeys: *k,
|
|
||||||
StreamID: streamID,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// OneTimeKeys represents a set of one-time keys for a single device
|
// OneTimeKeys represents a set of one-time keys for a single device
|
||||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||||
type OneTimeKeys struct {
|
type OneTimeKeys struct {
|
||||||
|
@ -110,7 +89,7 @@ type OneTimeKeysCount struct {
|
||||||
type PerformUploadKeysRequest struct {
|
type PerformUploadKeysRequest struct {
|
||||||
UserID string // Required - User performing the request
|
UserID string // Required - User performing the request
|
||||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||||
DeviceKeys []DeviceKeys
|
DeviceKeys []gomatrixserverlib.DeviceKeys
|
||||||
OneTimeKeys []OneTimeKeys
|
OneTimeKeys []OneTimeKeys
|
||||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||||
// the display name for their respective device, and NOT to modify the keys. The key
|
// the display name for their respective device, and NOT to modify the keys. The key
|
||||||
|
@ -161,7 +140,7 @@ type QueryKeysResponse struct {
|
||||||
// Map of remote server domain to error JSON
|
// Map of remote server domain to error JSON
|
||||||
Failures map[string]interface{}
|
Failures map[string]interface{}
|
||||||
// Map of user_id to device_id to device_key
|
// Map of user_id to device_id to device_key
|
||||||
DeviceKeys map[string]map[string]json.RawMessage
|
DeviceKeys map[string]map[string]gomatrixserverlib.DeviceKeys
|
||||||
// Set if there was a fatal error processing this query
|
// Set if there was a fatal error processing this query
|
||||||
Error *KeyError
|
Error *KeyError
|
||||||
}
|
}
|
||||||
|
|
|
@ -16,7 +16,6 @@ package internal
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"hash/fnv"
|
"hash/fnv"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -215,23 +214,15 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
|
||||||
"prev_ids": event.PrevID,
|
"prev_ids": event.PrevID,
|
||||||
"display_name": event.DeviceDisplayName,
|
"display_name": event.DeviceDisplayName,
|
||||||
"deleted": event.Deleted,
|
"deleted": event.Deleted,
|
||||||
|
"keys": event.Keys,
|
||||||
}).Info("DeviceListUpdater.Update")
|
}).Info("DeviceListUpdater.Update")
|
||||||
|
|
||||||
// if we haven't missed anything update the database and notify users
|
// if we haven't missed anything update the database and notify users
|
||||||
if exists {
|
if exists {
|
||||||
k := event.Keys
|
|
||||||
if event.Deleted {
|
|
||||||
k = nil
|
|
||||||
}
|
|
||||||
keys := []api.DeviceMessage{
|
keys := []api.DeviceMessage{
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: event.Keys,
|
||||||
DeviceID: event.DeviceID,
|
StreamID: event.StreamID,
|
||||||
DisplayName: event.DeviceDisplayName,
|
|
||||||
KeyJSON: k,
|
|
||||||
UserID: event.UserID,
|
|
||||||
},
|
|
||||||
StreamID: event.StreamID,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||||
|
@ -388,24 +379,15 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
|
||||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||||
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
||||||
for i, device := range res.Devices {
|
for i, device := range res.Devices {
|
||||||
keyJSON, err := json.Marshal(device.Keys)
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
keys[i] = api.DeviceMessage{
|
keys[i] = api.DeviceMessage{
|
||||||
StreamID: res.StreamID,
|
StreamID: res.StreamID,
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: device.DeviceID,
|
RespUserDeviceKeys: device.Keys,
|
||||||
DisplayName: device.DisplayName,
|
|
||||||
UserID: res.UserID,
|
|
||||||
KeyJSON: keyJSON,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
existingKeys[i] = api.DeviceMessage{
|
existingKeys[i] = api.DeviceMessage{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
UserID: res.UserID,
|
RespUserDeviceKeys: device.Keys,
|
||||||
DeviceID: device.DeviceID,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ package internal
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/ed25519"
|
"crypto/ed25519"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -128,10 +129,19 @@ func TestUpdateHavePrevID(t *testing.T) {
|
||||||
DeviceDisplayName: "Foo Bar",
|
DeviceDisplayName: "Foo Bar",
|
||||||
Deleted: false,
|
Deleted: false,
|
||||||
DeviceID: "FOO",
|
DeviceID: "FOO",
|
||||||
Keys: []byte(`{"key":"value"}`),
|
Keys: &gomatrixserverlib.DeviceKeys{
|
||||||
PrevID: []int{0},
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
StreamID: 1,
|
DeviceID: "FOO",
|
||||||
UserID: "@alice:localhost",
|
UserID: "@alice:localhost",
|
||||||
|
Algorithms: []string{"TEST"},
|
||||||
|
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
|
||||||
|
"key": {1, 2, 3, 4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
PrevID: []int{0},
|
||||||
|
StreamID: 1,
|
||||||
|
UserID: "@alice:localhost",
|
||||||
}
|
}
|
||||||
err := updater.Update(ctx, event)
|
err := updater.Update(ctx, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -139,11 +149,15 @@ func TestUpdateHavePrevID(t *testing.T) {
|
||||||
}
|
}
|
||||||
want := api.DeviceMessage{
|
want := api.DeviceMessage{
|
||||||
StreamID: event.StreamID,
|
StreamID: event.StreamID,
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: event.DeviceID,
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
DisplayName: event.DeviceDisplayName,
|
DeviceID: "FOO",
|
||||||
KeyJSON: event.Keys,
|
UserID: "@alice:localhost",
|
||||||
UserID: event.UserID,
|
Algorithms: []string{"TEST"},
|
||||||
|
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
|
||||||
|
"key": {1, 2, 3, 4, 5, 6},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||||
|
@ -201,11 +215,16 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
DeviceDisplayName: "Mobile Phone",
|
DeviceDisplayName: "Mobile Phone",
|
||||||
Deleted: false,
|
Deleted: false,
|
||||||
DeviceID: "another_device_id",
|
DeviceID: "another_device_id",
|
||||||
Keys: []byte(`{"key":"value"}`),
|
Keys: &gomatrixserverlib.DeviceKeys{},
|
||||||
PrevID: []int{3},
|
PrevID: []int{3},
|
||||||
StreamID: 4,
|
StreamID: 4,
|
||||||
UserID: remoteUserID,
|
UserID: remoteUserID,
|
||||||
}
|
}
|
||||||
|
if err := json.Unmarshal([]byte(keyJSON), event.Keys); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
event.DeviceID = "another_device"
|
||||||
|
event.Keys.DeviceID = "another_device"
|
||||||
err := updater.Update(ctx, event)
|
err := updater.Update(ctx, event)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Update returned an error: %s", err)
|
t.Fatalf("Update returned an error: %s", err)
|
||||||
|
@ -215,24 +234,25 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
// wait a bit for db to be updated...
|
// wait a bit for db to be updated...
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
want := api.DeviceMessage{
|
want := api.DeviceMessage{
|
||||||
StreamID: 5,
|
StreamID: 5,
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||||
DeviceID: "JLAFKJWSCS",
|
}
|
||||||
DisplayName: "Mobile Phone",
|
if err := json.Unmarshal([]byte(keyJSON), want.DeviceKeys); err != nil {
|
||||||
UserID: remoteUserID,
|
t.Fatal(err)
|
||||||
KeyJSON: []byte(keyJSON),
|
}
|
||||||
},
|
want.Unsigned = map[string]interface{}{
|
||||||
|
"device_display_name": "Mobile Phone",
|
||||||
}
|
}
|
||||||
// Now we should have a fresh list and the keys and emitted something
|
// Now we should have a fresh list and the keys and emitted something
|
||||||
if db.staleUsers[event.UserID] {
|
if db.staleUsers[event.UserID] {
|
||||||
t.Errorf("%s still marked as stale", event.UserID)
|
t.Errorf("%s still marked as stale", event.UserID)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
if len(producer.events) != 1 {
|
||||||
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
t.Logf("len got %+v len want 1", len(producer.events))
|
||||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||||
}
|
}
|
||||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
if !reflect.DeepEqual(*db.storedKeys[0].DeviceKeys, *want.DeviceKeys) {
|
||||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
t.Errorf("DB didn't store correct event\ngot %+v\nwant %+v", *db.storedKeys[0].DeviceKeys, *want.DeviceKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -30,8 +30,6 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type KeyInternalAPI struct {
|
type KeyInternalAPI struct {
|
||||||
|
@ -210,7 +208,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
||||||
// remove deleted devices
|
// remove deleted devices
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
for _, m := range msgs {
|
for _, m := range msgs {
|
||||||
if m.KeyJSON == nil {
|
if m.DeviceKeys == nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
result = append(result, m)
|
result = append(result, m)
|
||||||
|
@ -220,7 +218,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
|
||||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
res.DeviceKeys = make(map[string]map[string]gomatrixserverlib.DeviceKeys)
|
||||||
res.Failures = make(map[string]interface{})
|
res.Failures = make(map[string]interface{})
|
||||||
// make a map from domain to device keys
|
// make a map from domain to device keys
|
||||||
domainToDeviceKeys := make(map[string]map[string][]string)
|
domainToDeviceKeys := make(map[string]map[string][]string)
|
||||||
|
@ -254,21 +252,19 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
|
||||||
}
|
}
|
||||||
|
|
||||||
if res.DeviceKeys[userID] == nil {
|
if res.DeviceKeys[userID] == nil {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||||
}
|
}
|
||||||
for _, dk := range deviceKeys {
|
for _, dk := range deviceKeys {
|
||||||
if len(dk.KeyJSON) == 0 {
|
if dk.DeviceKeys == nil {
|
||||||
continue // don't include blank keys
|
continue // don't include blank keys
|
||||||
}
|
}
|
||||||
// inject display name if known (either locally or remotely)
|
// inject display name if known (either locally or remotely)
|
||||||
displayName := dk.DisplayName
|
displayName := dk.DisplayName()
|
||||||
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||||
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||||
}
|
}
|
||||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
dk.Unsigned["device_display_name"] = displayName
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
res.DeviceKeys[userID][dk.DeviceID] = *dk.DeviceKeys
|
||||||
}{displayName})
|
|
||||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
domainToDeviceKeys[domain] = make(map[string][]string)
|
domainToDeviceKeys[domain] = make(map[string][]string)
|
||||||
|
@ -335,13 +331,9 @@ func (a *KeyInternalAPI) queryRemoteKeys(
|
||||||
|
|
||||||
for result := range resultCh {
|
for result := range resultCh {
|
||||||
for userID, nest := range result.DeviceKeys {
|
for userID, nest := range result.DeviceKeys {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||||
for deviceID, deviceKey := range nest {
|
for deviceID, deviceKey := range nest {
|
||||||
keyJSON, err := json.Marshal(deviceKey)
|
res.DeviceKeys[userID][deviceID] = deviceKey
|
||||||
if err != nil {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
res.DeviceKeys[userID][deviceID] = keyJSON
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -432,18 +424,16 @@ func (a *KeyInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||||
}
|
}
|
||||||
if res.DeviceKeys[userID] == nil {
|
if res.DeviceKeys[userID] == nil {
|
||||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
res.DeviceKeys[userID] = make(map[string]gomatrixserverlib.DeviceKeys)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
if len(key.KeyJSON) == 0 {
|
if key.DeviceKeys == nil {
|
||||||
continue // ignore deleted keys
|
continue // ignore deleted keys
|
||||||
}
|
}
|
||||||
// inject the display name
|
// inject the display name
|
||||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
key.DeviceKeys.Unsigned["device_display_name"] = key.DisplayName()
|
||||||
DisplayName string `json:"device_display_name,omitempty"`
|
res.DeviceKeys[userID][key.DeviceID] = *key.DeviceKeys
|
||||||
}{key.DisplayName})
|
|
||||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -459,21 +449,19 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
if serverName != a.ThisServer {
|
if serverName != a.ThisServer {
|
||||||
continue // ignore remote users
|
continue // ignore remote users
|
||||||
}
|
}
|
||||||
if len(key.KeyJSON) == 0 {
|
if len(key.Keys) == 0 {
|
||||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0})
|
||||||
continue // deleted keys don't need sanity checking
|
continue // deleted keys don't need sanity checking
|
||||||
}
|
}
|
||||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
if req.UserID == key.UserID && req.DeviceID == key.DeviceID {
|
||||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
keysToStore = append(keysToStore, api.DeviceMessage{DeviceKeys: &key, StreamID: 0})
|
||||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
|
||||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
||||||
Err: fmt.Sprintf(
|
Err: fmt.Sprintf(
|
||||||
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
||||||
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
req.UserID, key.UserID, req.DeviceID, key.DeviceID,
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -482,9 +470,11 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
|
||||||
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
existingKeys := make([]api.DeviceMessage, len(keysToStore))
|
||||||
for i := range keysToStore {
|
for i := range keysToStore {
|
||||||
existingKeys[i] = api.DeviceMessage{
|
existingKeys[i] = api.DeviceMessage{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
UserID: keysToStore[i].UserID,
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
DeviceID: keysToStore[i].DeviceID,
|
UserID: keysToStore[i].UserID,
|
||||||
|
DeviceID: keysToStore[i].DeviceID,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -574,9 +564,18 @@ func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.Device
|
||||||
for _, newKey := range new {
|
for _, newKey := range new {
|
||||||
exists := false
|
exists := false
|
||||||
for _, existingKey := range existing {
|
for _, existingKey := range existing {
|
||||||
|
newJSON, err := json.Marshal(newKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("json.Marshal(newKey): %w", err)
|
||||||
|
}
|
||||||
|
existingJSON, err := json.Marshal(existingKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("json.Marshal(existingKey): %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
||||||
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
||||||
if bytes.Equal(existingKey.KeyJSON, newKey.KeyJSON) && len(existingKey.KeyJSON) > 0 {
|
if bytes.Equal(existingJSON, newJSON) && len(existingJSON) > 0 {
|
||||||
exists = true
|
exists = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -594,7 +593,7 @@ func appendDisplayNames(existing, new []api.DeviceMessage) []api.DeviceMessage {
|
||||||
if existingDevice.DeviceID != newDevice.DeviceID {
|
if existingDevice.DeviceID != newDevice.DeviceID {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
existingDevice.DisplayName = newDevice.DisplayName
|
existingDevice.Unsigned["device_display_name"] = newDevice.DisplayName()
|
||||||
existing[i] = existingDevice
|
existing[i] = existingDevice
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var deviceKeysSchema = `
|
var deviceKeysSchema = `
|
||||||
|
@ -105,19 +106,23 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for _, key := range keys {
|
||||||
var keyJSONStr string
|
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
if key.DeviceKeys == nil {
|
||||||
|
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
|
||||||
|
}
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
key.StreamID = streamID
|
||||||
keys[i].StreamID = streamID
|
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
keys[i].DisplayName = displayName.String
|
if key.DeviceKeys.Unsigned == nil {
|
||||||
|
key.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -153,7 +158,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -179,18 +184,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
}
|
}
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceMessage
|
dk := api.DeviceMessage{
|
||||||
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||||
|
}
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dk.DisplayName = displayName.String
|
if dk.DeviceKeys.Unsigned == nil {
|
||||||
|
dk.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||||
}
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var deviceKeysSchema = `
|
var deviceKeysSchema = `
|
||||||
|
@ -113,18 +114,21 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||||
var result []api.DeviceMessage
|
var result []api.DeviceMessage
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
var dk api.DeviceMessage
|
dk := api.DeviceMessage{
|
||||||
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{},
|
||||||
|
}
|
||||||
dk.UserID = userID
|
dk.UserID = userID
|
||||||
var keyJSON string
|
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
if err := rows.Scan(&dk.DeviceID, &dk.DeviceKeys, &streamID, &displayName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
dk.KeyJSON = []byte(keyJSON)
|
|
||||||
dk.StreamID = streamID
|
dk.StreamID = streamID
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
dk.DisplayName = displayName.String
|
if dk.DeviceKeys.Unsigned == nil {
|
||||||
|
dk.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
dk.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||||
}
|
}
|
||||||
// include the key if we want all keys (no device) or it was asked
|
// include the key if we want all keys (no device) or it was asked
|
||||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||||
|
@ -135,19 +139,23 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||||
for i, key := range keys {
|
for _, key := range keys {
|
||||||
var keyJSONStr string
|
|
||||||
var streamID int
|
var streamID int
|
||||||
var displayName sql.NullString
|
var displayName sql.NullString
|
||||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
if key.DeviceKeys == nil {
|
||||||
|
key.DeviceKeys = &gomatrixserverlib.DeviceKeys{}
|
||||||
|
}
|
||||||
|
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&key.DeviceKeys, &streamID, &displayName)
|
||||||
if err != nil && err != sql.ErrNoRows {
|
if err != nil && err != sql.ErrNoRows {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
// this will be '' when there is no device
|
// this will be '' when there is no device
|
||||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
key.StreamID = streamID
|
||||||
keys[i].StreamID = streamID
|
|
||||||
if displayName.Valid {
|
if displayName.Valid {
|
||||||
keys[i].DisplayName = displayName.String
|
if key.DeviceKeys.Unsigned == nil {
|
||||||
|
key.DeviceKeys.Unsigned = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
key.DeviceKeys.Unsigned["device_display_name"] = displayName.String
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@ -189,7 +197,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
|
||||||
for _, key := range keys {
|
for _, key := range keys {
|
||||||
now := time.Now().Unix()
|
now := time.Now().Unix()
|
||||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
|
|
@ -12,6 +12,7 @@ import (
|
||||||
"github.com/Shopify/sarama"
|
"github.com/Shopify/sarama"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/setup/config"
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ctx = context.Background()
|
var ctx = context.Background()
|
||||||
|
@ -105,28 +106,34 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||||
msgs := []api.DeviceMessage{
|
msgs := []api.DeviceMessage{
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: "AAA",
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
UserID: alice,
|
DeviceID: "AAA",
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
UserID: alice,
|
||||||
|
Algorithms: []string{"v1"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
// StreamID: 1
|
// StreamID: 1
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: "AAA",
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
UserID: bob,
|
DeviceID: "AAA",
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
UserID: bob,
|
||||||
|
Algorithms: []string{"v1"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
// StreamID: 1 as this is a different user
|
// StreamID: 1 as this is a different user
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: "another_device",
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
UserID: alice,
|
DeviceID: "another_device",
|
||||||
KeyJSON: []byte(`{"key":"v1"}`),
|
UserID: alice,
|
||||||
|
Algorithms: []string{"v2"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
// StreamID: 2 as this is a 2nd device key
|
// StreamID: 2 as this is a 2nd key
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||||
|
@ -143,10 +150,12 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||||
// updating a device sets the next stream ID for that user
|
// updating a device sets the next stream ID for that user
|
||||||
msgs = []api.DeviceMessage{
|
msgs = []api.DeviceMessage{
|
||||||
{
|
{
|
||||||
DeviceKeys: api.DeviceKeys{
|
DeviceKeys: &gomatrixserverlib.DeviceKeys{
|
||||||
DeviceID: "AAA",
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
UserID: alice,
|
DeviceID: "AAA",
|
||||||
KeyJSON: []byte(`{"key":"v2"}`),
|
UserID: alice,
|
||||||
|
Algorithms: []string{"v3"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
// StreamID: 3
|
// StreamID: 3
|
||||||
},
|
},
|
||||||
|
|
|
@ -150,12 +150,13 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
||||||
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
|
deviceKeys := make([]gomatrixserverlib.DeviceKeys, len(deviceIDs))
|
||||||
for i, did := range deviceIDs {
|
for i, did := range deviceIDs {
|
||||||
deviceKeys[i] = keyapi.DeviceKeys{
|
deviceKeys[i] = gomatrixserverlib.DeviceKeys{
|
||||||
UserID: userID,
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
DeviceID: did,
|
UserID: userID,
|
||||||
KeyJSON: nil,
|
DeviceID: did,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -219,12 +220,15 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
||||||
var uploadRes keyapi.PerformUploadKeysResponse
|
var uploadRes keyapi.PerformUploadKeysResponse
|
||||||
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||||
UserID: req.RequestingUserID,
|
UserID: req.RequestingUserID,
|
||||||
DeviceKeys: []keyapi.DeviceKeys{
|
DeviceKeys: []gomatrixserverlib.DeviceKeys{
|
||||||
{
|
{
|
||||||
DeviceID: dev.ID,
|
RespUserDeviceKeys: gomatrixserverlib.RespUserDeviceKeys{
|
||||||
DisplayName: *req.DisplayName,
|
DeviceID: dev.ID,
|
||||||
KeyJSON: nil,
|
UserID: dev.UserID,
|
||||||
UserID: dev.UserID,
|
},
|
||||||
|
Unsigned: map[string]interface{}{
|
||||||
|
"device_display_name": *req.DisplayName,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
OnlyDisplayNameUpdates: true,
|
OnlyDisplayNameUpdates: true,
|
||||||
|
|
Loading…
Reference in a new issue