From 1fed50767954c86347cb3004701a1acbce53762b Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 30 Jul 2021 10:23:36 +0100 Subject: [PATCH] Check per-user --- keyserver/internal/internal.go | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 21f4793f..2873b604 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -227,17 +227,13 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques res.Failures = make(map[string]interface{}) // get cross-signing keys from the database - crossSigningSatisfiedLocally := true if err := a.crossSigningKeys(ctx, req, res); err != nil { util.GetLogger(ctx).WithError(err).Error("Failed to retrieve cross-signing keys from database") - crossSigningSatisfiedLocally = false - } - if len(res.MasterKeys) == 0 || len(res.SelfSigningKeys) == 0 { - crossSigningSatisfiedLocally = false } // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) + domainToCrossSigningKeys := make(map[string]struct{}) for userID, deviceIDs := range req.UserToDevices { _, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -288,11 +284,21 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques domainToDeviceKeys[domain] = make(map[string][]string) domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...) } + // work out if our cross-signing request for this user was + // satisfied + if _, ok := domainToCrossSigningKeys[userID]; !ok { + if _, ok := res.MasterKeys[userID]; !ok { + domainToCrossSigningKeys[userID] = struct{}{} + } + if _, ok := res.SelfSigningKeys[userID]; !ok { + domainToCrossSigningKeys[userID] = struct{}{} + } + } } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) - if len(domainToDeviceKeys) == 0 && crossSigningSatisfiedLocally { + if len(domainToDeviceKeys) == 0 && len(domainToCrossSigningKeys) == 0 { return // nothing to query }