From 8630960141e0a005640f005ae73d4faa6c2e660c Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 30 Jul 2021 11:04:26 +0100 Subject: [PATCH] Try this again --- keyserver/internal/internal.go | 39 +++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 9d037876..eb87fae1 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -231,6 +231,7 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // make a map from domain to device keys domainToDeviceKeys := make(map[string]map[string][]string) + domainToCrossSigningKeys := make(map[string]map[string]struct{}) for userID, deviceIDs := range req.UserToDevices { _, serverName, err := gomatrixserverlib.SplitID('@', userID) if err != nil { @@ -284,37 +285,41 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques // work out if our cross-signing request for this user was // satisfied, if not add them to the list of things to fetch if _, ok := res.MasterKeys[userID]; !ok { - util.GetLogger(ctx).Infof("No cross-signing master keys for %s found", userID) - if _, ok := domainToDeviceKeys[domain]; !ok { - domainToDeviceKeys[domain] = make(map[string][]string) + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) } - if _, ok := domainToDeviceKeys[domain][userID]; !ok { - util.GetLogger(ctx).Infof("Request cross-signing keys from %s for %s", domain, userID) - domainToDeviceKeys[domain][userID] = []string{} - } else { - util.GetLogger(ctx).Infof("Already requesting keys from %s for %s", domain, userID) + if _, ok := domainToCrossSigningKeys[domain][userID]; !ok { + domainToCrossSigningKeys[domain][userID] = struct{}{} } } if _, ok := res.SelfSigningKeys[userID]; !ok { - util.GetLogger(ctx).Infof("No cross-signing self-signing keys for %s found", userID) - if _, ok := domainToDeviceKeys[domain]; !ok { - domainToDeviceKeys[domain] = make(map[string][]string) + if _, ok := domainToCrossSigningKeys[domain]; !ok { + domainToCrossSigningKeys[domain] = make(map[string]struct{}) } - if _, ok := domainToDeviceKeys[domain][userID]; !ok { - util.GetLogger(ctx).Infof("Request cross-signing keys from %s for %s", domain, userID) - domainToDeviceKeys[domain][userID] = []string{} - } else { - util.GetLogger(ctx).Infof("Already requesting keys from %s for %s", domain, userID) + if _, ok := domainToCrossSigningKeys[domain][userID]; !ok { + domainToCrossSigningKeys[domain][userID] = struct{}{} } } } // attempt to satisfy key queries from the local database first as we should get device updates pushed to us domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys) - if len(domainToDeviceKeys) == 0 { + if len(domainToDeviceKeys) == 0 && len(domainToCrossSigningKeys) == 0 { return // nothing to query } + // add in any cross-signing requests that need to be made to the list + for domain, forDomain := range domainToCrossSigningKeys { + for userID := range forDomain { + if _, ok := domainToDeviceKeys[domain]; !ok { + domainToDeviceKeys[domain] = make(map[string][]string) + } + if _, ok := domainToDeviceKeys[domain][userID]; !ok { + domainToDeviceKeys[domain][userID] = []string{} + } + } + } + // perform key queries for remote devices a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys) }