mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
This commit is contained in:
parent
bd6f0c14e5
commit
4594233f89
107 changed files with 1730 additions and 1863 deletions
587
userapi/internal/cross_signing.go
Normal file
587
userapi/internal/cross_signing.go
Normal file
|
@ -0,0 +1,587 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error {
|
||||
// Is there exactly one key?
|
||||
if len(key.Keys) != 1 {
|
||||
return fmt.Errorf("should contain exactly one key")
|
||||
}
|
||||
|
||||
// Does the key ID match the key value? Iterates exactly once
|
||||
for keyID, keyData := range key.Keys {
|
||||
b64 := keyData.Encode()
|
||||
tokens := strings.Split(string(keyID), ":")
|
||||
if len(tokens) != 2 {
|
||||
return fmt.Errorf("key ID is incorrectly formatted")
|
||||
}
|
||||
if tokens[1] != b64 {
|
||||
return fmt.Errorf("key ID isn't correct")
|
||||
}
|
||||
switch tokens[0] {
|
||||
case "ed25519":
|
||||
if len(keyData) != ed25519.PublicKeySize {
|
||||
return fmt.Errorf("ed25519 key is not the correct length")
|
||||
}
|
||||
case "curve25519":
|
||||
if len(keyData) != curve25519.PointSize {
|
||||
return fmt.Errorf("curve25519 key is not the correct length")
|
||||
}
|
||||
default:
|
||||
// We can't enforce the key length to be correct for an
|
||||
// algorithm that we don't recognise, so instead we'll
|
||||
// just make sure that it isn't incredibly excessive.
|
||||
if l := len(keyData); l > 4096 {
|
||||
return fmt.Errorf("unknown key type is too long (%d bytes)", l)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check to see if the signatures make sense
|
||||
for _, forOriginUser := range key.Signatures {
|
||||
for originKeyID, originSignature := range forOriginUser {
|
||||
switch strings.SplitN(string(originKeyID), ":", 1)[0] {
|
||||
case "ed25519":
|
||||
if len(originSignature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("ed25519 signature is not the correct length")
|
||||
}
|
||||
case "curve25519":
|
||||
return fmt.Errorf("curve25519 signatures are impossible")
|
||||
default:
|
||||
if l := len(originSignature); l > 4096 {
|
||||
return fmt.Errorf("unknown signature type is too long (%d bytes)", l)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Does the key claim to be from the right user?
|
||||
if userID != key.UserID {
|
||||
return fmt.Errorf("key has a user ID mismatch")
|
||||
}
|
||||
|
||||
// Does the key contain the correct purpose?
|
||||
useful := false
|
||||
for _, usage := range key.Usage {
|
||||
if usage == purpose {
|
||||
useful = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !useful {
|
||||
return fmt.Errorf("key does not contain correct usage purpose")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
|
||||
// Find the keys to store.
|
||||
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
|
||||
toStore := types.CrossSigningKeyMap{}
|
||||
hasMasterKey := false
|
||||
|
||||
if len(req.MasterKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Master key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
|
||||
for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key
|
||||
}
|
||||
hasMasterKey = true
|
||||
}
|
||||
|
||||
if len(req.SelfSigningKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Self-signing key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
|
||||
for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.UserSigningKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "User-signing key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
|
||||
for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key
|
||||
}
|
||||
}
|
||||
|
||||
// If there's nothing to do then stop here.
|
||||
if len(toStore) == 0 {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No keys were supplied in the request",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// We can't have a self-signing or user-signing key without a master
|
||||
// key, so make sure we have one of those. We will also only actually do
|
||||
// something if any of the specified keys in the request are different
|
||||
// to what we've got in the database, to avoid generating key change
|
||||
// notifications unnecessarily.
|
||||
existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we still can't find a master key for the user then stop the upload.
|
||||
// This satisfies the "Fails to upload self-signing key without master key" test.
|
||||
if !hasMasterKey {
|
||||
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No master key was found",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if anything actually changed compared to what we have in the database.
|
||||
changed := false
|
||||
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
|
||||
gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||
} {
|
||||
old, gotOld := existingKeys[purpose]
|
||||
new, gotNew := toStore[purpose]
|
||||
if gotOld != gotNew {
|
||||
// A new key purpose has been specified that we didn't know before,
|
||||
// or one has been removed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
if !bytes.Equal(old, new) {
|
||||
// One of the existing keys for a purpose we already knew about has
|
||||
// changed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store the keys.
|
||||
if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now upload any signatures that were included with the keys.
|
||||
for _, key := range byPurpose {
|
||||
var targetKeyID gomatrixserverlib.KeyID
|
||||
for targetKey := range key.Keys { // iterates once, see sanityCheckKey
|
||||
targetKeyID = targetKey
|
||||
}
|
||||
for sigUserID, forSigUserID := range key.Signatures {
|
||||
if sigUserID != req.UserID {
|
||||
continue
|
||||
}
|
||||
for sigKeyID, sigBytes := range forSigUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, generate a notification that we updated the keys.
|
||||
update := api.CrossSigningKeyUpdate{
|
||||
UserID: req.UserID,
|
||||
}
|
||||
if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok {
|
||||
update.MasterKey = &mk
|
||||
}
|
||||
if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok {
|
||||
update.SelfSigningKey = &ssk
|
||||
}
|
||||
if update.MasterKey == nil && update.SelfSigningKey == nil {
|
||||
return nil
|
||||
}
|
||||
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
|
||||
// Before we do anything, we need the master and self-signing keys for this user.
|
||||
// Then we can verify the signatures make sense.
|
||||
queryReq := &api.QueryKeysRequest{
|
||||
UserID: req.UserID,
|
||||
UserToDevices: map[string][]string{},
|
||||
}
|
||||
queryRes := &api.QueryKeysResponse{}
|
||||
for userID := range req.Signatures {
|
||||
queryReq.UserToDevices[userID] = []string{}
|
||||
}
|
||||
_ = a.QueryKeys(ctx, queryReq, queryRes)
|
||||
|
||||
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
|
||||
// Sort signatures into two groups: one where people have signed their own
|
||||
// keys and one where people have signed someone elses
|
||||
for userID, forUserID := range req.Signatures {
|
||||
for keyID, keyOrDevice := range forUserID {
|
||||
switch key := keyOrDevice.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
if key.UserID == req.UserID {
|
||||
if _, ok := selfSignatures[userID]; !ok {
|
||||
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
selfSignatures[userID][keyID] = keyOrDevice
|
||||
} else {
|
||||
if _, ok := otherSignatures[userID]; !ok {
|
||||
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
otherSignatures[userID][keyID] = keyOrDevice
|
||||
}
|
||||
|
||||
case *gomatrixserverlib.DeviceKeys:
|
||||
if key.UserID == req.UserID {
|
||||
if _, ok := selfSignatures[userID]; !ok {
|
||||
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
selfSignatures[userID][keyID] = keyOrDevice
|
||||
} else {
|
||||
if _, ok := otherSignatures[userID]; !ok {
|
||||
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
otherSignatures[userID][keyID] = keyOrDevice
|
||||
}
|
||||
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.processSelfSignatures(ctx, selfSignatures); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finally, generate a notification that we updated the signatures.
|
||||
for userID := range req.Signatures {
|
||||
masterKey := queryRes.MasterKeys[userID]
|
||||
selfSigningKey := queryRes.SelfSigningKeys[userID]
|
||||
update := api.CrossSigningKeyUpdate{
|
||||
UserID: userID,
|
||||
MasterKey: &masterKey,
|
||||
SelfSigningKey: &selfSigningKey,
|
||||
}
|
||||
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) processSelfSignatures(
|
||||
ctx context.Context,
|
||||
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
|
||||
) error {
|
||||
// Here we will process:
|
||||
// * The user signing their own devices using their self-signing key
|
||||
// * The user signing their master key using one of their devices
|
||||
|
||||
for targetUserID, forTargetUserID := range signatures {
|
||||
for targetKeyID, signature := range forTargetUserID {
|
||||
switch sig := signature.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
for keyID := range sig.Keys {
|
||||
split := strings.SplitN(string(keyID), ":", 2)
|
||||
if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
|
||||
targetKeyID = keyID // contains the ed25519: or other scheme
|
||||
break
|
||||
}
|
||||
}
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case *gomatrixserverlib.DeviceKeys:
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected type assertion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) processOtherSignatures(
|
||||
ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
|
||||
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
|
||||
) error {
|
||||
// Here we will process:
|
||||
// * A user signing someone else's master keys using their user-signing keys
|
||||
|
||||
for targetUserID, forTargetUserID := range signatures {
|
||||
for _, signature := range forTargetUserID {
|
||||
switch sig := signature.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
// Find the local copy of the master key. We'll use this to be
|
||||
// sure that the supplied stanza matches the key that we think it
|
||||
// should be.
|
||||
masterKey, ok := queryRes.MasterKeys[targetUserID]
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find master key for user %q", targetUserID)
|
||||
}
|
||||
|
||||
// For each key ID, write the signatures. Maybe there'll be more
|
||||
// than one algorithm in the future so it's best not to focus on
|
||||
// everything being ed25519:.
|
||||
for targetKeyID, suppliedKeyData := range sig.Keys {
|
||||
// The master key will be supplied in the request, but we should
|
||||
// make sure that it matches what we think the master key should
|
||||
// actually be.
|
||||
localKeyData, lok := masterKey.Keys[targetKeyID]
|
||||
if !lok {
|
||||
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
|
||||
} else if !bytes.Equal(suppliedKeyData, localKeyData) {
|
||||
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
|
||||
}
|
||||
|
||||
// We only care about the signatures from the uploading user, so
|
||||
// we will ignore anything that didn't originate from them.
|
||||
userSigs, ok := sig.Signatures[userID]
|
||||
if !ok {
|
||||
return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID)
|
||||
}
|
||||
|
||||
for originKeyID, originSig := range userSigs {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Users should only be signing another person's master key,
|
||||
// so if we're here, it's probably because it's actually a
|
||||
// gomatrixserverlib.DeviceKeys, which doesn't make sense.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
||||
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
|
||||
) {
|
||||
for targetUserID := range req.UserToDevices {
|
||||
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
|
||||
continue
|
||||
}
|
||||
|
||||
for keyType, key := range keys {
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
for id := range key.Keys {
|
||||
keyID = id
|
||||
break
|
||||
}
|
||||
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
|
||||
continue
|
||||
}
|
||||
|
||||
appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) {
|
||||
if key.Signatures == nil {
|
||||
key.Signatures = types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := key.Signatures[originUserID]; !ok {
|
||||
key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes)
|
||||
}
|
||||
key.Signatures[originUserID][originKeyID] = signature
|
||||
}
|
||||
|
||||
for originUserID, forOrigin := range sigMap {
|
||||
for originKeyID, signature := range forOrigin {
|
||||
switch {
|
||||
case req.UserID != "" && originUserID == req.UserID:
|
||||
// Include signatures that we created
|
||||
appendSignature(originUserID, originKeyID, signature)
|
||||
case originUserID == targetUserID:
|
||||
// Include signatures that were created by the person whose key
|
||||
// we are processing
|
||||
appendSignature(originUserID, originKeyID, signature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch keyType {
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
|
||||
res.MasterKeys[targetUserID] = key
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
|
||||
res.SelfSigningKeys[targetUserID] = key
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
|
||||
res.UserSigningKeys[targetUserID] = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
|
||||
for targetUserID, forTargetUser := range req.TargetIDs {
|
||||
keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for targetPurpose, targetKey := range keyMap {
|
||||
switch targetPurpose {
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
|
||||
if res.MasterKeys == nil {
|
||||
res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.MasterKeys[targetUserID] = targetKey
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
|
||||
if res.SelfSigningKeys == nil {
|
||||
res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.SelfSigningKeys[targetUserID] = targetKey
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
|
||||
if res.UserSigningKeys == nil {
|
||||
res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.UserSigningKeys[targetUserID] = targetKey
|
||||
}
|
||||
}
|
||||
|
||||
for _, targetKeyID := range forTargetUser {
|
||||
// Get own signatures only.
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if res.Signatures == nil {
|
||||
res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID]; !ok {
|
||||
res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok {
|
||||
res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok {
|
||||
res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
579
userapi/internal/device_list_update.go
Normal file
579
userapi/internal/device_list_update.go
Normal file
|
@ -0,0 +1,579 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
var (
|
||||
deviceListUpdateCount = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: "dendrite",
|
||||
Subsystem: "keyserver",
|
||||
Name: "device_list_update",
|
||||
Help: "Number of times we have attempted to update device lists from this server",
|
||||
},
|
||||
[]string{"server"},
|
||||
)
|
||||
)
|
||||
|
||||
const requestTimeout = time.Second * 30
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
deviceListUpdateCount,
|
||||
)
|
||||
}
|
||||
|
||||
// DeviceListUpdater handles device list updates from remote servers.
|
||||
//
|
||||
// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
|
||||
// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
|
||||
// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
|
||||
// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
|
||||
// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
|
||||
// updater stores the latest list along with the latest stream ID.
|
||||
//
|
||||
// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
|
||||
// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
|
||||
// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
|
||||
// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
|
||||
// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
|
||||
// for that domain.
|
||||
// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
|
||||
// we have many many servers)
|
||||
// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
|
||||
//
|
||||
// The downsides are that:
|
||||
// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
|
||||
// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
|
||||
// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
|
||||
// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
|
||||
// than being stuck behind foo.bar
|
||||
//
|
||||
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
|
||||
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
|
||||
type DeviceListUpdater struct {
|
||||
process *process.ProcessContext
|
||||
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
||||
// request to the remote server and race.
|
||||
// TODO: Put in an LRU cache to bound growth
|
||||
userIDToMutex map[string]*sync.Mutex
|
||||
mu *sync.Mutex // protects UserIDToMutex
|
||||
|
||||
db DeviceListUpdaterDatabase
|
||||
api DeviceListUpdaterAPI
|
||||
producer KeyChangeProducer
|
||||
fedClient fedsenderapi.KeyserverFederationAPI
|
||||
workerChans []chan gomatrixserverlib.ServerName
|
||||
thisServer gomatrixserverlib.ServerName
|
||||
|
||||
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
||||
// block on or timeout via a select.
|
||||
userIDToChan map[string]chan bool
|
||||
userIDToChanMu *sync.Mutex
|
||||
rsAPI rsapi.KeyserverRoomserverAPI
|
||||
}
|
||||
|
||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||
// Useful for testing.
|
||||
type DeviceListUpdaterDatabase interface {
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
|
||||
}
|
||||
|
||||
type DeviceListUpdaterAPI interface {
|
||||
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
|
||||
}
|
||||
|
||||
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
|
||||
type KeyChangeProducer interface {
|
||||
ProduceKeyChanges(keys []api.DeviceMessage) error
|
||||
}
|
||||
|
||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||
func NewDeviceListUpdater(
|
||||
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
||||
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||
rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
process: process,
|
||||
userIDToMutex: make(map[string]*sync.Mutex),
|
||||
mu: &sync.Mutex{},
|
||||
db: db,
|
||||
api: api,
|
||||
producer: producer,
|
||||
fedClient: fedClient,
|
||||
thisServer: thisServer,
|
||||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||
userIDToChan: make(map[string]chan bool),
|
||||
userIDToChanMu: &sync.Mutex{},
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
}
|
||||
|
||||
// Start the device list updater, which will try to refresh any stale device lists.
|
||||
func (u *DeviceListUpdater) Start() error {
|
||||
for i := 0; i < len(u.workerChans); i++ {
|
||||
// Allocate a small buffer per channel.
|
||||
// If the buffer limit is reached, backpressure will cause the processing of EDUs
|
||||
// to stop (in this transaction) until key requests can be made.
|
||||
ch := make(chan gomatrixserverlib.ServerName, 10)
|
||||
u.workerChans[i] = ch
|
||||
go u.worker(ch)
|
||||
}
|
||||
|
||||
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offset, step := time.Second*10, time.Second
|
||||
if max := len(staleLists); max > 120 {
|
||||
step = (time.Second * 120) / time.Duration(max)
|
||||
}
|
||||
for _, userID := range staleLists {
|
||||
userID := userID // otherwise we are only sending the last entry
|
||||
time.AfterFunc(offset, func() {
|
||||
u.notifyWorkers(userID)
|
||||
})
|
||||
offset += step
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanUp removes stale device entries for users we don't share a room with anymore
|
||||
func (u *DeviceListUpdater) CleanUp() error {
|
||||
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := rsapi.QueryLeftUsersResponse{}
|
||||
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(res.LeftUsers) == 0 {
|
||||
return nil
|
||||
}
|
||||
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
|
||||
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if u.userIDToMutex[userID] == nil {
|
||||
u.userIDToMutex[userID] = &sync.Mutex{}
|
||||
}
|
||||
return u.userIDToMutex[userID]
|
||||
}
|
||||
|
||||
// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
|
||||
// Blocks until the device list is synced or the timeout is reached.
|
||||
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
|
||||
mu := u.mutex(userID)
|
||||
mu.Lock()
|
||||
err := u.db.MarkDeviceListStale(ctx, userID, true)
|
||||
mu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
|
||||
}
|
||||
u.notifyWorkers(userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
|
||||
// which assumes when /send 200 OKs that the device lists have been updated.
|
||||
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
||||
isDeviceListStale, err := u.update(ctx, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isDeviceListStale {
|
||||
// poke workers to handle stale device lists
|
||||
u.notifyWorkers(event.UserID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
|
||||
mu := u.mutex(event.UserID)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// check if we have the prev IDs
|
||||
exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
// if this is the first time we're hearing about this user, sync the device list manually.
|
||||
if len(event.PrevID) == 0 {
|
||||
exists = false
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"prev_ids_exist": exists,
|
||||
"user_id": event.UserID,
|
||||
"device_id": event.DeviceID,
|
||||
"stream_id": event.StreamID,
|
||||
"prev_ids": event.PrevID,
|
||||
"display_name": event.DeviceDisplayName,
|
||||
"deleted": event.Deleted,
|
||||
}).Trace("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists || event.Deleted {
|
||||
k := event.Keys
|
||||
if event.Deleted {
|
||||
k = nil
|
||||
}
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: k,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
StreamID: event.StreamID,
|
||||
},
|
||||
}
|
||||
|
||||
// DeviceKeysJSON will side-effect modify this, so it needs
|
||||
// to be a copy, not sharing any pointers with the above.
|
||||
deviceKeysCopy := *keys[0].DeviceKeys
|
||||
deviceKeysCopy.KeyJSON = nil
|
||||
existingKeys := []api.DeviceMessage{
|
||||
{
|
||||
Type: keys[0].Type,
|
||||
DeviceKeys: &deviceKeysCopy,
|
||||
StreamID: keys[0].StreamID,
|
||||
},
|
||||
}
|
||||
|
||||
// fetch what keys we had already and only emit changes
|
||||
if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||
// non-fatal, log and continue
|
||||
util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf(
|
||||
"failed to query device keys json for calculating diffs",
|
||||
)
|
||||
}
|
||||
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
|
||||
if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
|
||||
return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) notifyWorkers(userID string) {
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hash := fnv.New32a()
|
||||
_, _ = hash.Write([]byte(remoteServer))
|
||||
index := int(int64(hash.Sum32()) % int64(len(u.workerChans)))
|
||||
|
||||
ch := u.assignChannel(userID)
|
||||
u.workerChans[index] <- remoteServer
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(10 * time.Second):
|
||||
// we don't return an error in this case as it's not a failure condition.
|
||||
// we mainly block for the benefit of sytest anyway
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
|
||||
u.userIDToChanMu.Lock()
|
||||
defer u.userIDToChanMu.Unlock()
|
||||
if ch, ok := u.userIDToChan[userID]; ok {
|
||||
return ch
|
||||
}
|
||||
ch := make(chan bool)
|
||||
u.userIDToChan[userID] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) clearChannel(userID string) {
|
||||
u.userIDToChanMu.Lock()
|
||||
defer u.userIDToChanMu.Unlock()
|
||||
if ch, ok := u.userIDToChan[userID]; ok {
|
||||
close(ch)
|
||||
delete(u.userIDToChan, userID)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||
retries := make(map[gomatrixserverlib.ServerName]time.Time)
|
||||
retriesMu := &sync.Mutex{}
|
||||
// restarter goroutine which will inject failed servers into ch when it is time
|
||||
go func() {
|
||||
var serversToRetry []gomatrixserverlib.ServerName
|
||||
for {
|
||||
serversToRetry = serversToRetry[:0] // reuse memory
|
||||
time.Sleep(time.Second)
|
||||
retriesMu.Lock()
|
||||
now := time.Now()
|
||||
for srv, retryAt := range retries {
|
||||
if now.After(retryAt) {
|
||||
serversToRetry = append(serversToRetry, srv)
|
||||
}
|
||||
}
|
||||
for _, srv := range serversToRetry {
|
||||
delete(retries, srv)
|
||||
}
|
||||
retriesMu.Unlock()
|
||||
for _, srv := range serversToRetry {
|
||||
ch <- srv
|
||||
}
|
||||
}
|
||||
}()
|
||||
for serverName := range ch {
|
||||
retriesMu.Lock()
|
||||
_, exists := retries[serverName]
|
||||
retriesMu.Unlock()
|
||||
if exists {
|
||||
// Don't retry a server that we're already waiting for.
|
||||
continue
|
||||
}
|
||||
waitTime, shouldRetry := u.processServer(serverName)
|
||||
if shouldRetry {
|
||||
retriesMu.Lock()
|
||||
if _, exists = retries[serverName]; !exists {
|
||||
retries[serverName] = time.Now().Add(waitTime)
|
||||
}
|
||||
retriesMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
|
||||
ctx := u.process.Context()
|
||||
logger := util.GetLogger(ctx).WithField("server_name", serverName)
|
||||
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
|
||||
|
||||
waitTime := defaultWaitTime // How long should we wait to try again?
|
||||
successCount := 0 // How many user requests failed?
|
||||
|
||||
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to load stale device lists")
|
||||
return waitTime, true
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, userID := range userIDs {
|
||||
// always clear the channel to unblock Update calls regardless of success/failure
|
||||
u.clearChannel(userID)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, userID := range userIDs {
|
||||
userWait, err := u.processServerUser(ctx, serverName, userID)
|
||||
if err != nil {
|
||||
if userWait > waitTime {
|
||||
waitTime = userWait
|
||||
}
|
||||
break
|
||||
}
|
||||
successCount++
|
||||
}
|
||||
|
||||
allUsersSucceeded := successCount == len(userIDs)
|
||||
if !allUsersSucceeded {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"total": len(userIDs),
|
||||
"succeeded": successCount,
|
||||
"failed": len(userIDs) - successCount,
|
||||
"wait_time": waitTime,
|
||||
}).Debug("Failed to query device keys for some users")
|
||||
}
|
||||
return waitTime, !allUsersSucceeded
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
||||
defer cancel()
|
||||
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"user_id": userID,
|
||||
})
|
||||
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case *json.UnmarshalTypeError, *json.SyntaxError:
|
||||
logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
|
||||
return defaultWaitTime, nil
|
||||
case *fedsenderapi.FederationClientError:
|
||||
if e.RetryAfter > 0 {
|
||||
return e.RetryAfter, err
|
||||
} else if e.Blacklisted {
|
||||
return time.Hour * 8, err
|
||||
}
|
||||
case net.Error:
|
||||
// Use the default waitTime, if it's a timeout.
|
||||
// It probably doesn't make sense to try further users.
|
||||
if !e.Timeout() {
|
||||
logger.WithError(e).Debug("GetUserDevices returned net.Error")
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
case gomatrix.HTTPError:
|
||||
// The remote server returned an error, give it some time to recover.
|
||||
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
|
||||
if e.Code >= 300 {
|
||||
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
|
||||
return hourWaitTime, err
|
||||
}
|
||||
default:
|
||||
// Something else failed
|
||||
logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
}
|
||||
if res.UserID != userID {
|
||||
logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
|
||||
return defaultWaitTime, nil
|
||||
}
|
||||
if res.MasterKey != nil || res.SelfSigningKey != nil {
|
||||
uploadReq := &api.PerformUploadDeviceKeysRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
if res.MasterKey != nil {
|
||||
if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
|
||||
uploadReq.MasterKey = *res.MasterKey
|
||||
}
|
||||
}
|
||||
if res.SelfSigningKey != nil {
|
||||
if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
|
||||
uploadReq.SelfSigningKey = *res.SelfSigningKey
|
||||
}
|
||||
}
|
||||
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
|
||||
}
|
||||
err = u.updateDeviceList(&res)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Fetched device list but failed to store/emit it")
|
||||
return defaultWaitTime, err
|
||||
}
|
||||
return defaultWaitTime, nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
|
||||
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
|
||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
||||
for i, device := range res.Devices {
|
||||
keyJSON, err := json.Marshal(device.Keys)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
||||
continue
|
||||
}
|
||||
keys[i] = api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: res.StreamID,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: device.DeviceID,
|
||||
DisplayName: device.DisplayName,
|
||||
UserID: res.UserID,
|
||||
KeyJSON: keyJSON,
|
||||
},
|
||||
}
|
||||
existingKeys[i] = api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: res.UserID,
|
||||
DeviceID: device.DeviceID,
|
||||
},
|
||||
}
|
||||
}
|
||||
// fetch what keys we had already and only emit changes
|
||||
if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||
// non-fatal, log and continue
|
||||
util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf(
|
||||
"failed to query device keys json for calculating diffs",
|
||||
)
|
||||
}
|
||||
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||
}
|
||||
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark device list as fresh: %w", err)
|
||||
}
|
||||
err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
22
userapi/internal/device_list_update_default.go
Normal file
22
userapi/internal/device_list_update_default.go
Normal file
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
const defaultWaitTime = time.Minute
|
||||
const hourWaitTime = time.Hour
|
25
userapi/internal/device_list_update_sytest.go
Normal file
25
userapi/internal/device_list_update_sytest.go
Normal file
|
@ -0,0 +1,25 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// Sytest is expecting to receive a `/devices` request. The way it is implemented in Dendrite
|
||||
// results in a one-hour wait time from a previous device so the test times out. This is fine for
|
||||
// production, but makes an otherwise passing test fail.
|
||||
const defaultWaitTime = time.Second
|
||||
const hourWaitTime = time.Second
|
431
userapi/internal/device_list_update_test.go
Normal file
431
userapi/internal/device_list_update_test.go
Normal file
|
@ -0,0 +1,431 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
type mockKeyChangeProducer struct {
|
||||
events []api.DeviceMessage
|
||||
}
|
||||
|
||||
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
p.events = append(p.events, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeviceListUpdaterDatabase struct {
|
||||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int64) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
var result []string
|
||||
for userID, isStale := range d.staleUsers {
|
||||
if !isStale {
|
||||
continue
|
||||
}
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
result = append(result, userID)
|
||||
continue
|
||||
}
|
||||
for _, d := range domains {
|
||||
if remoteServer == d {
|
||||
result = append(result, userID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.staleUsers[userID] = isStale
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.staleUsers[userID]
|
||||
}
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||
d.storedKeys = append(d.storedKeys, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
return d.prevIDsExist(userID, prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeviceListUpdaterAPI struct {
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
fn func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.fn(req)
|
||||
}
|
||||
|
||||
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
|
||||
_, pkey, _ := ed25519.GenerateKey(nil)
|
||||
fedClient := gomatrixserverlib.NewFederationClient(
|
||||
[]*gomatrixserverlib.SigningIdentity{
|
||||
{
|
||||
ServerName: gomatrixserverlib.ServerName("example.test"),
|
||||
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
|
||||
PrivateKey: pkey,
|
||||
},
|
||||
},
|
||||
)
|
||||
fedClient.Client = *gomatrixserverlib.NewClient(
|
||||
gomatrixserverlib.WithTransport(&roundTripper{tripper}),
|
||||
)
|
||||
return fedClient
|
||||
}
|
||||
|
||||
// Test that the device keys get persisted and emitted if we have the previous IDs.
|
||||
func TestUpdateHavePrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Foo Bar",
|
||||
Deleted: false,
|
||||
DeviceID: "FOO",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int64{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
want := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: event.StreamID,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that device keys are fetched from the remote server if we are missing prev IDs
|
||||
// and that the user's devices are marked as stale until it succeeds.
|
||||
func TestUpdateNoPrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
remoteUserID := "@alice:example.somewhere"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
defer wg.Done()
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + remoteUserID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Mobile Phone",
|
||||
Deleted: false,
|
||||
DeviceID: "another_device_id",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int64{3},
|
||||
StreamID: 4,
|
||||
UserID: remoteUserID,
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
t.Log("waiting for /users/devices to be called...")
|
||||
wg.Wait()
|
||||
// wait a bit for db to be updated...
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
want := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: 5,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "JLAFKJWSCS",
|
||||
DisplayName: "Mobile Phone",
|
||||
UserID: remoteUserID,
|
||||
KeyJSON: []byte(keyJSON),
|
||||
},
|
||||
}
|
||||
// Now we should have a fresh list and the keys and emitted something
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s still marked as stale", event.UserID)
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
|
||||
// update is still ongoing.
|
||||
func TestDebounce(t *testing.T) {
|
||||
t.Skipf("panic on closed channel on GHA")
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
fedCh := make(chan *http.Response, 1)
|
||||
srv := gomatrixserverlib.ServerName("example.com")
|
||||
userID := "@alice:example.com"
|
||||
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
incomingFedReq := make(chan struct{})
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
close(incomingFedReq)
|
||||
return <-fedCh, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
||||
// hit this 5 times
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
|
||||
t.Errorf("ManualUpdate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until the updater hits federation
|
||||
select {
|
||||
case <-incomingFedReq:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timed out waiting for updater to hit federation")
|
||||
}
|
||||
|
||||
// user should be marked as stale
|
||||
if !db.isStale(userID) {
|
||||
t.Errorf("user %s not marked as stale", userID)
|
||||
}
|
||||
// now send the response over federation
|
||||
fedCh <- &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + userID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}
|
||||
close(fedCh)
|
||||
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
|
||||
// and should panic with read on a closed channel
|
||||
wg.Wait()
|
||||
|
||||
// user is no longer stale now
|
||||
if db.isStale(userID) {
|
||||
t.Errorf("user %s is marked as stale", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
|
||||
base, _, _ := testrig.Base(nil)
|
||||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return db, clearDB
|
||||
}
|
||||
|
||||
type mockKeyserverRoomserverAPI struct {
|
||||
leftUsers []string
|
||||
}
|
||||
|
||||
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||
res.LeftUsers = m.leftUsers
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDeviceListUpdater_CleanUp(t *testing.T) {
|
||||
processCtx := process.NewProcessContext()
|
||||
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
|
||||
// Bob is not joined to any of our rooms
|
||||
rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clearDB := mustCreateKeyserverDB(t, dbType)
|
||||
defer clearDB()
|
||||
|
||||
// This should not get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// this one should get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updater := NewDeviceListUpdater(processCtx, db, nil,
|
||||
nil, nil,
|
||||
0, rsAPI, "test")
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// check that we still have Alice in our stale list
|
||||
staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// There should only be Alice
|
||||
wantCount := 1
|
||||
if count := len(staleUsers); count != wantCount {
|
||||
t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
|
||||
}
|
||||
|
||||
if staleUsers[0] != alice.ID {
|
||||
t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
|
||||
}
|
||||
})
|
||||
}
|
798
userapi/internal/key_api.go
Normal file
798
userapi/internal/key_api.go
Normal file
|
@ -0,0 +1,798 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
|
||||
userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Offset = latest
|
||||
res.UserIDs = userIDs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
|
||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||
if len(req.DeviceKeys) > 0 {
|
||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
|
||||
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
||||
res.Failures = make(map[string]interface{})
|
||||
// wrap request map in a top-level by-domain map
|
||||
domainToDeviceKeys := make(map[string]map[string]map[string]string)
|
||||
for userID, val := range req.OneTimeKeys {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
nested, ok := domainToDeviceKeys[string(serverName)]
|
||||
if !ok {
|
||||
nested = make(map[string]map[string]string)
|
||||
}
|
||||
nested[userID] = val
|
||||
domainToDeviceKeys[string(serverName)] = nested
|
||||
}
|
||||
for domain, local := range domainToDeviceKeys {
|
||||
if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
// claim local keys
|
||||
keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
||||
}
|
||||
}
|
||||
util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
|
||||
for _, key := range keys {
|
||||
_, ok := res.OneTimeKeys[key.UserID]
|
||||
if !ok {
|
||||
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
|
||||
}
|
||||
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
|
||||
if !ok {
|
||||
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for keyID, keyJSON := range key.KeyJSON {
|
||||
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||
}
|
||||
}
|
||||
delete(domainToDeviceKeys, domain)
|
||||
}
|
||||
if len(domainToDeviceKeys) > 0 {
|
||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) claimRemoteKeys(
|
||||
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
||||
) {
|
||||
var wg sync.WaitGroup // Wait for fan-out goroutines to finish
|
||||
var mu sync.Mutex // Protects the response struct
|
||||
var claimed int // Number of keys claimed in total
|
||||
var failures int // Number of servers we failed to ask
|
||||
|
||||
util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
|
||||
wg.Add(len(domainToDeviceKeys))
|
||||
|
||||
for d, k := range domainToDeviceKeys {
|
||||
go func(domain string, keysToClaim map[string]map[string]string) {
|
||||
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
defer wg.Done()
|
||||
|
||||
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
||||
res.Failures[domain] = map[string]interface{}{
|
||||
"message": err.Error(),
|
||||
}
|
||||
failures++
|
||||
return
|
||||
}
|
||||
|
||||
for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
|
||||
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
||||
for deviceID, keys := range deviceIDToKeys {
|
||||
res.OneTimeKeys[userID][deviceID] = keys
|
||||
claimed += len(keys)
|
||||
}
|
||||
}
|
||||
}(d, k)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"num_keys": claimed,
|
||||
"num_failures": failures,
|
||||
}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
|
||||
if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
|
||||
count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Count = *count
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
maxStreamID := int64(0)
|
||||
// remove deleted devices
|
||||
var result []api.DeviceMessage
|
||||
for _, m := range msgs {
|
||||
if m.StreamID > maxStreamID {
|
||||
maxStreamID = m.StreamID
|
||||
}
|
||||
if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
|
||||
continue
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
res.Devices = result
|
||||
res.StreamID = maxStreamID
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformMarkAsStaleIfNeeded marks the users device list as stale, if the given deviceID is not present
|
||||
// in our database.
|
||||
func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
|
||||
knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(knownDevices) == 0 {
|
||||
return nil // fmt.Errorf("unknown user %s", req.UserID)
|
||||
}
|
||||
|
||||
for i := range knownDevices {
|
||||
if knownDevices[i].DeviceID == req.DeviceID {
|
||||
return nil // we already know about this device
|
||||
}
|
||||
}
|
||||
|
||||
return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
||||
var respMu sync.Mutex
|
||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.Failures = make(map[string]interface{})
|
||||
|
||||
// make a map from domain to device keys
|
||||
domainToDeviceKeys := make(map[string]map[string][]string)
|
||||
domainToCrossSigningKeys := make(map[string]map[string]struct{})
|
||||
for userID, deviceIDs := range req.UserToDevices {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
domain := string(serverName)
|
||||
// query local devices
|
||||
if a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pull out display names after we have the keys so we handle wildcards correctly
|
||||
var dids []string
|
||||
for _, dk := range deviceKeys {
|
||||
dids = append(dids, dk.DeviceID)
|
||||
}
|
||||
var queryRes api.QueryDeviceInfosResponse
|
||||
err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
|
||||
DeviceIDs: dids,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
||||
}
|
||||
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for _, dk := range deviceKeys {
|
||||
if len(dk.KeyJSON) == 0 {
|
||||
continue // don't include blank keys
|
||||
}
|
||||
// inject display name if known (either locally or remotely)
|
||||
displayName := dk.DisplayName
|
||||
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||
}
|
||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{displayName})
|
||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||
}
|
||||
} else {
|
||||
domainToDeviceKeys[domain] = make(map[string][]string)
|
||||
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
||||
}
|
||||
// work out if our cross-signing request for this user was
|
||||
// satisfied, if not add them to the list of things to fetch
|
||||
if _, ok := res.MasterKeys[userID]; !ok {
|
||||
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
||||
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
||||
}
|
||||
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
||||
}
|
||||
if _, ok := res.SelfSigningKeys[userID]; !ok {
|
||||
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
||||
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
||||
}
|
||||
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
|
||||
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
||||
// perform key queries for remote devices
|
||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
||||
}
|
||||
|
||||
// Now that we've done the potentially expensive work of asking the federation,
|
||||
// try filling the cross-signing keys from the database that we know about.
|
||||
a.crossSigningKeysFromDatabase(ctx, req, res)
|
||||
|
||||
// Finally, append signatures that we know about
|
||||
// TODO: This is horrible because we need to round-trip the signature from
|
||||
// JSON, add the signatures and marshal it again, for some reason?
|
||||
|
||||
for targetUserID, masterKey := range res.MasterKeys {
|
||||
if masterKey.Signatures == nil {
|
||||
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for targetKeyID := range masterKey.Keys {
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
continue
|
||||
}
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if _, ok := masterKey.Signatures[sourceUserID]; !ok {
|
||||
masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for targetUserID, forUserID := range res.DeviceKeys {
|
||||
for targetKeyID, key := range forUserID {
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
continue
|
||||
}
|
||||
var deviceKey gomatrixserverlib.DeviceKeys
|
||||
if err = json.Unmarshal(key, &deviceKey); err != nil {
|
||||
continue
|
||||
}
|
||||
if deviceKey.Signatures == nil {
|
||||
deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
|
||||
deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
if js, err := json.Marshal(deviceKey); err == nil {
|
||||
res.DeviceKeys[targetUserID][targetKeyID] = js
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) remoteKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
|
||||
) map[string]map[string][]string {
|
||||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
for userID, deviceIDs := range userToDeviceMap {
|
||||
// we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) > 0 {
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
|
||||
}
|
||||
// fetch device lists from remote
|
||||
if _, ok := fetchRemote[domain]; !ok {
|
||||
fetchRemote[domain] = make(map[string][]string)
|
||||
}
|
||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||
|
||||
}
|
||||
}
|
||||
return fetchRemote
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) queryRemoteKeys(
|
||||
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
|
||||
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
|
||||
) {
|
||||
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
|
||||
// allows us to wait until all federation servers have been poked
|
||||
var wg sync.WaitGroup
|
||||
// mutex for writing directly to res (e.g failures)
|
||||
var respMu sync.Mutex
|
||||
|
||||
domains := map[string]struct{}{}
|
||||
for domain := range domainToDeviceKeys {
|
||||
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
for domain := range domainToCrossSigningKeys {
|
||||
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
wg.Add(len(domains))
|
||||
|
||||
// fan out
|
||||
for domain := range domains {
|
||||
go a.queryRemoteKeysOnServer(
|
||||
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
|
||||
&wg, &respMu, timeout, resultCh, res,
|
||||
)
|
||||
}
|
||||
|
||||
// Close the result channel when the goroutines have quit so the for .. range exits
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
processResult := func(result *gomatrixserverlib.RespQueryKeys) {
|
||||
respMu.Lock()
|
||||
defer respMu.Unlock()
|
||||
for userID, nest := range result.DeviceKeys {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
for deviceID, deviceKey := range nest {
|
||||
keyJSON, err := json.Marshal(deviceKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res.DeviceKeys[userID][deviceID] = keyJSON
|
||||
}
|
||||
}
|
||||
|
||||
for userID, body := range result.MasterKeys {
|
||||
res.MasterKeys[userID] = body
|
||||
}
|
||||
|
||||
for userID, body := range result.SelfSigningKeys {
|
||||
res.SelfSigningKeys[userID] = body
|
||||
}
|
||||
|
||||
// TODO: do we want to persist these somewhere now
|
||||
// that we have fetched them?
|
||||
}
|
||||
|
||||
for result := range resultCh {
|
||||
processResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
||||
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
|
||||
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
|
||||
res *api.QueryKeysResponse,
|
||||
) {
|
||||
defer wg.Done()
|
||||
fedCtx := ctx
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
fedCtx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
// for users who we do not have any knowledge about, try to start doing device list updates for them
|
||||
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
|
||||
// lack a stream ID.
|
||||
userIDsForAllDevices := map[string]struct{}{}
|
||||
for userID, deviceIDs := range devKeys {
|
||||
if len(deviceIDs) == 0 {
|
||||
userIDsForAllDevices[userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
|
||||
// a device list update, so we'll populate those back into the /keys/query list if not
|
||||
for userID := range crossSigningKeys {
|
||||
if devKeys == nil {
|
||||
devKeys = map[string][]string{}
|
||||
}
|
||||
if _, ok := userIDsForAllDevices[userID]; !ok {
|
||||
devKeys[userID] = []string{}
|
||||
}
|
||||
}
|
||||
for userID := range userIDsForAllDevices {
|
||||
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(devKeys) == 0 {
|
||||
return
|
||||
}
|
||||
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
|
||||
if err == nil {
|
||||
resultCh <- &queryKeysResp
|
||||
return
|
||||
}
|
||||
respMu.Lock()
|
||||
res.Failures[serverName] = map[string]interface{}{
|
||||
"message": err.Error(),
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
|
||||
// is down, better to return something than nothing at all. Clients can know about the failure by
|
||||
// inspecting the failures map though so they can know it's a cached response.
|
||||
for userID, dkeys := range devKeys {
|
||||
// drop the error as it's already a failure at this point
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
|
||||
}
|
||||
|
||||
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
||||
respMu.Lock()
|
||||
if len(res.DeviceKeys) > 0 {
|
||||
delete(res.Failures, serverName)
|
||||
}
|
||||
respMu.Unlock()
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||
}
|
||||
if len(keys) < len(deviceIDs) {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
|
||||
}
|
||||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
respMu.Lock()
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
continue // ignore deleted keys
|
||||
}
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
respMu.Lock()
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
respMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
// get a list of devices from the user API that actually exist, as
|
||||
// we won't store keys for devices that don't exist
|
||||
uapidevices := &api.QueryDevicesResponse{}
|
||||
if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
return
|
||||
}
|
||||
if !uapidevices.UserExists {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("user %q does not exist", req.UserID),
|
||||
}
|
||||
return
|
||||
}
|
||||
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
|
||||
for _, key := range uapidevices.Devices {
|
||||
existingDeviceMap[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all of the user existing device keys so we can check for changes.
|
||||
existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Work out whether we have device keys in the keyserver for devices that
|
||||
// no longer exist in the user API. This is mostly an exercise to ensure
|
||||
// that we keep some integrity between the two.
|
||||
var toClean []gomatrixserverlib.KeyID
|
||||
for _, k := range existingKeys {
|
||||
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
|
||||
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(toClean) > 0 {
|
||||
if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
|
||||
} else {
|
||||
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
|
||||
}
|
||||
}
|
||||
|
||||
var keysToStore []api.DeviceMessage
|
||||
|
||||
if req.OnlyDisplayNameUpdates {
|
||||
for _, existingKey := range existingKeys {
|
||||
for _, newKey := range req.DeviceKeys {
|
||||
switch {
|
||||
case existingKey.UserID != newKey.UserID:
|
||||
continue
|
||||
case existingKey.DeviceID != newKey.DeviceID:
|
||||
continue
|
||||
case existingKey.DisplayName != newKey.DisplayName:
|
||||
existingKey.DisplayName = newKey.DisplayName
|
||||
}
|
||||
}
|
||||
keysToStore = append(keysToStore, existingKey)
|
||||
}
|
||||
} else {
|
||||
// assert that the user ID / device ID are not lying for each key
|
||||
for _, key := range req.DeviceKeys {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
continue // ignore remote users
|
||||
}
|
||||
if len(key.KeyJSON) == 0 {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
continue // deleted keys don't need sanity checking
|
||||
}
|
||||
// check that the device in question actually exists in the user
|
||||
// API before we try and store a key for it
|
||||
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
|
||||
continue
|
||||
}
|
||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
continue
|
||||
}
|
||||
|
||||
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf(
|
||||
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
||||
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// store the device keys and emit changes
|
||||
err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
}
|
||||
}
|
||||
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, key := range req.OneTimeKeys {
|
||||
// grab existing keys based on (user/device/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time keys: " + err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
for keyIDWithAlgo := range existingKeys {
|
||||
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
// if we only want to update the display names, we can skip the checks below
|
||||
if onlyUpdateDisplayName {
|
||||
return producer.ProduceKeyChanges(new)
|
||||
}
|
||||
// find keys in new that are not in existing
|
||||
var keysAdded []api.DeviceMessage
|
||||
for _, newKey := range new {
|
||||
exists := false
|
||||
for _, existingKey := range existing {
|
||||
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
||||
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
||||
if existingKey.DeviceKeysEqual(&newKey) {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
keysAdded = append(keysAdded, newKey)
|
||||
}
|
||||
}
|
||||
return producer.ProduceKeyChanges(keysAdded)
|
||||
}
|
161
userapi/internal/key_api_test.go
Normal file
161
userapi/internal/key_api_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package internal_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
base, _, _ := testrig.Base(nil)
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new user db: %v", err)
|
||||
}
|
||||
return db, func() {
|
||||
base.Close()
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryDeviceMessages(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
type args struct {
|
||||
req *api.QueryDeviceMessagesRequest
|
||||
res *api.QueryDeviceMessagesResponse
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
want *api.QueryDeviceMessagesResponse
|
||||
}{
|
||||
{
|
||||
name: "no existing keys",
|
||||
args: args{
|
||||
req: &api.QueryDeviceMessagesRequest{
|
||||
UserID: "@doesNotExist:localhost",
|
||||
},
|
||||
res: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
want: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
{
|
||||
name: "existing user returns devices",
|
||||
args: args{
|
||||
req: &api.QueryDeviceMessagesRequest{
|
||||
UserID: alice.ID,
|
||||
},
|
||||
res: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
want: &api.QueryDeviceMessagesResponse{
|
||||
StreamID: 6,
|
||||
Devices: []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
DisplayName: "first device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("ghi"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "mySecondDevice",
|
||||
DisplayName: "second device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("jkl"),
|
||||
}, // streamID 6
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
deviceMessages := []api.DeviceMessage{
|
||||
{ // not the user we're looking for
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
UserID: "@doesNotExist:localhost",
|
||||
},
|
||||
// streamID 1 for this user
|
||||
},
|
||||
{ // empty keyJSON will be ignored
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
}, // streamID 1
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("abc"),
|
||||
}, // streamID 2
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("def"),
|
||||
}, // streamID 3
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte(""),
|
||||
}, // streamID 4
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
DisplayName: "first device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("ghi"),
|
||||
}, // streamID 5
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "mySecondDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("jkl"),
|
||||
DisplayName: "second device",
|
||||
}, // streamID 6
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
|
||||
t.Fatalf("failed to store local devicesKeys")
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &internal.UserInternalAPI{
|
||||
KeyDatabase: db,
|
||||
}
|
||||
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
|
||||
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
got := tt.args.res
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -23,6 +23,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -32,7 +33,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
synctypes "github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
@ -44,17 +44,19 @@ import (
|
|||
)
|
||||
|
||||
type UserInternalAPI struct {
|
||||
DB storage.Database
|
||||
SyncProducer *producers.SyncAPI
|
||||
Config *config.UserAPI
|
||||
DB storage.UserDatabase
|
||||
KeyDatabase storage.KeyDatabase
|
||||
SyncProducer *producers.SyncAPI
|
||||
KeyChangeProducer *producers.KeyChange
|
||||
Config *config.UserAPI
|
||||
|
||||
DisableTLSValidation bool
|
||||
// AppServices is the list of all registered AS
|
||||
AppServices []config.ApplicationService
|
||||
KeyAPI keyapi.UserKeyAPI
|
||||
RSAPI rsapi.UserRoomserverAPI
|
||||
PgClient pushgateway.Client
|
||||
Cfg *config.UserAPI
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
|
@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||
postRegisterJoinRooms(a.Config, acc, a.RSAPI)
|
||||
|
||||
res.AccountCreated = true
|
||||
res.Account = acc
|
||||
|
@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
return err
|
||||
}
|
||||
// Ask the keyserver to delete device keys and signatures for those devices
|
||||
deleteReq := &keyapi.PerformDeleteKeysRequest{
|
||||
deleteReq := &api.PerformDeleteKeysRequest{
|
||||
UserID: req.UserID,
|
||||
}
|
||||
for _, keyID := range req.DeviceIDs {
|
||||
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
|
||||
}
|
||||
deleteRes := &keyapi.PerformDeleteKeysResponse{}
|
||||
if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
|
||||
deleteRes := &api.PerformDeleteKeysResponse{}
|
||||
if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := deleteRes.Error; err != nil {
|
||||
|
@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
||||
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
|
||||
deviceKeys := make([]api.DeviceKeys, len(deviceIDs))
|
||||
for i, did := range deviceIDs {
|
||||
deviceKeys[i] = keyapi.DeviceKeys{
|
||||
deviceKeys[i] = api.DeviceKeys{
|
||||
UserID: userID,
|
||||
DeviceID: did,
|
||||
KeyJSON: nil,
|
||||
}
|
||||
}
|
||||
|
||||
var uploadRes keyapi.PerformUploadKeysResponse
|
||||
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
|
||||
UserID: userID,
|
||||
DeviceKeys: deviceKeys,
|
||||
}, &uploadRes); err != nil {
|
||||
|
@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
}
|
||||
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
|
||||
// display name has changed: update the device key
|
||||
var uploadRes keyapi.PerformUploadKeysResponse
|
||||
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
|
||||
UserID: req.RequestingUserID,
|
||||
DeviceKeys: []keyapi.DeviceKeys{
|
||||
DeviceKeys: []api.DeviceKeys{
|
||||
{
|
||||
DeviceID: dev.ID,
|
||||
DisplayName: *req.DisplayName,
|
Loading…
Add table
Add a link
Reference in a new issue