Cross-signing groundwork (#1953)

* Cross-signing groundwork

* Update to matrix-org/gomatrixserverlib#274

* Fix gobind builds, which stops unit tests in CI from yelling

* Some changes from review comments

* Fix build by passing in UIA

* Update to matrix-org/gomatrixserverlib@bec8d22

* Process master/self-signing keys from devices call

* nolint

* Enum-ify the key type in the database

* Process self-signing key too

* Fix sanity check in device list updater

* Fix check

* Fix sytest, hopefully

* Fix build
This commit is contained in:
Neil Alexander 2021-08-04 17:56:29 +01:00 committed by GitHub
parent 4cc8b28b7f
commit eb0efa4636
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
27 changed files with 860 additions and 50 deletions

View file

@ -221,9 +221,17 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.Failures = make(map[string]interface{})
// get cross-signing keys from the database
a.crossSigningKeysFromDatabase(ctx, req, res)
// make a map from domain to device keys
domainToDeviceKeys := make(map[string]map[string][]string)
domainToCrossSigningKeys := make(map[string]map[string]struct{})
for userID, deviceIDs := range req.UserToDevices {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
@ -274,16 +282,30 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain] = make(map[string][]string)
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
}
// work out if our cross-signing request for this user was
// satisfied, if not add them to the list of things to fetch
if _, ok := res.MasterKeys[userID]; !ok {
if _, ok := domainToCrossSigningKeys[domain]; !ok {
domainToCrossSigningKeys[domain] = make(map[string]struct{})
}
domainToCrossSigningKeys[domain][userID] = struct{}{}
}
if _, ok := res.SelfSigningKeys[userID]; !ok {
if _, ok := domainToCrossSigningKeys[domain]; !ok {
domainToCrossSigningKeys[domain] = make(map[string]struct{})
}
domainToCrossSigningKeys[domain][userID] = struct{}{}
}
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
if len(domainToDeviceKeys) == 0 {
if len(domainToDeviceKeys) == 0 && len(domainToCrossSigningKeys) == 0 {
return // nothing to query
}
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
@ -313,18 +335,30 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
}
func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) {
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
// allows us to wait until all federation servers have been poked
var wg sync.WaitGroup
wg.Add(len(domainToDeviceKeys))
// mutex for writing directly to res (e.g failures)
var respMu sync.Mutex
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
domains[domain] = struct{}{}
}
wg.Add(len(domains))
// fan out
for domain, deviceKeys := range domainToDeviceKeys {
go a.queryRemoteKeysOnServer(ctx, domain, deviceKeys, &wg, &respMu, timeout, resultCh, res)
for domain := range domains {
go a.queryRemoteKeysOnServer(
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
&wg, &respMu, timeout, resultCh, res,
)
}
// Close the result channel when the goroutines have quit so the for .. range exits
@ -344,12 +378,29 @@ func (a *KeyInternalAPI) queryRemoteKeys(
res.DeviceKeys[userID][deviceID] = keyJSON
}
}
for userID, body := range result.MasterKeys {
switch b := body.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey:
res.MasterKeys[userID] = *b
}
}
for userID, body := range result.SelfSigningKeys {
switch b := body.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey:
res.SelfSigningKeys[userID] = *b
}
}
// TODO: do we want to persist these somewhere now
// that we have fetched them?
}
}
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
ctx context.Context, serverName string, devKeys map[string][]string, wg *sync.WaitGroup,
respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
res *api.QueryKeysResponse,
) {
defer wg.Done()
@ -358,14 +409,24 @@ func (a *KeyInternalAPI) queryRemoteKeysOnServer(
// for users who we do not have any knowledge about, try to start doing device list updates for them
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
// lack a stream ID.
var userIDsForAllDevices []string
userIDsForAllDevices := map[string]struct{}{}
for userID, deviceIDs := range devKeys {
if len(deviceIDs) == 0 {
userIDsForAllDevices = append(userIDsForAllDevices, userID)
userIDsForAllDevices[userID] = struct{}{}
delete(devKeys, userID)
}
}
for _, userID := range userIDsForAllDevices {
// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
// a device list update, so we'll populate those back into the /keys/query list if not
for userID := range crossSigningKeys {
if devKeys == nil {
devKeys = map[string][]string{}
}
if _, ok := userIDsForAllDevices[userID]; !ok {
devKeys[userID] = []string{}
}
}
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
if err != nil {
logrus.WithFields(logrus.Fields{