diff --git a/keyserver/internal/device_list_update.go b/keyserver/internal/device_list_update.go index de239271..962d43aa 100644 --- a/keyserver/internal/device_list_update.go +++ b/keyserver/internal/device_list_update.go @@ -214,9 +214,15 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib. "prev_ids": event.PrevID, "display_name": event.DeviceDisplayName, "deleted": event.Deleted, - "keys": event.Keys, }).Info("DeviceListUpdater.Update") + event.Keys.Unsigned = map[string]interface{}{ + "device_display_name": event.DeviceDisplayName, + } + + fmt.Println("Display name:", event.DeviceDisplayName) + fmt.Println("Key display unsigned:", event.Keys.Unsigned) + // if we haven't missed anything update the database and notify users if exists { keys := []api.DeviceMessage{ @@ -385,6 +391,14 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi RespUserDeviceKeys: device.Keys, }, } + if device.DeviceID != "" { + keys[i].DeviceKeys.DeviceID = device.DeviceID + } + if device.DisplayName != "" { + keys[i].DeviceKeys.Unsigned = map[string]interface{}{ + "device_display_name": device.DisplayName, + } + } existingKeys[i] = api.DeviceMessage{ DeviceKeys: &gomatrixserverlib.DeviceKeys{ RespUserDeviceKeys: device.Keys, @@ -399,6 +413,9 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi ) } + fmt.Println("EXISTING KEYS:", existingKeys[0].Unsigned) + fmt.Println("NEW KEYS:", keys[0].Unsigned) + err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID}) if err != nil { return fmt.Errorf("failed to store remote device keys: %w", err) diff --git a/keyserver/internal/device_list_update_test.go b/keyserver/internal/device_list_update_test.go index 344d091b..923bf570 100644 --- a/keyserver/internal/device_list_update_test.go +++ b/keyserver/internal/device_list_update_test.go @@ -158,13 +158,16 @@ func TestUpdateHavePrevID(t *testing.T) { "key": {1, 2, 3, 4, 5, 6}, }, }, + Unsigned: map[string]interface{}{ + "device_display_name": "Foo Bar", + }, }, } if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) { t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want) } - if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) { - t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want) + if !reflect.DeepEqual(*db.storedKeys[0].DeviceKeys, *want.DeviceKeys) { + t.Errorf("DB didn't store correct event\ngot %+v\nwant %+v", *db.storedKeys[0].DeviceKeys, *want.DeviceKeys) } if db.staleUsers[event.UserID] { t.Errorf("%s incorrectly marked as stale", event.UserID) @@ -223,8 +226,6 @@ func TestUpdateNoPrevID(t *testing.T) { if err := json.Unmarshal([]byte(keyJSON), event.Keys); err != nil { t.Fatal(err) } - event.DeviceID = "another_device" - event.Keys.DeviceID = "another_device" err := updater.Update(ctx, event) if err != nil { t.Fatalf("Update returned an error: %s", err) diff --git a/keyserver/storage/postgres/device_keys_table.go b/keyserver/storage/postgres/device_keys_table.go index 243fd2f7..fb7b3658 100644 --- a/keyserver/storage/postgres/device_keys_table.go +++ b/keyserver/storage/postgres/device_keys_table.go @@ -158,7 +158,7 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx for _, key := range keys { now := time.Now().Unix() _, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext( - ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName, + ctx, key.UserID, key.DeviceID, now, key.DeviceKeys, key.StreamID, key.DisplayName(), ) if err != nil { return err