Merge branch 'master' into add-nats-support

This commit is contained in:
Neil Alexander 2021-11-02 17:36:22 +00:00
commit 73d6964fb4
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
205 changed files with 5074 additions and 1217 deletions

View file

@ -20,6 +20,8 @@ import (
"strings"
"time"
eduapi "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
@ -32,24 +34,40 @@ type KeyInternalAPI interface {
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse)
// PerformClaimKeys claims one-time keys for use in pre-key messages
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse)
PerformDeleteKeys(ctx context.Context, req *PerformDeleteKeysRequest, res *PerformDeleteKeysResponse)
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse)
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse)
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse)
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse)
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse)
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse)
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse)
}
// KeyError is returned if there was a problem performing/querying the server
type KeyError struct {
Err string
Err string `json:"error"`
IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
}
func (k *KeyError) Error() string {
return k.Err
}
type DeviceMessageType int
const (
TypeDeviceKeyUpdate DeviceMessageType = iota
TypeCrossSigningUpdate
)
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
DeviceKeys
Type DeviceMessageType `json:"Type,omitempty"`
*DeviceKeys `json:"DeviceKeys,omitempty"`
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
// A monotonically increasing number which represents device changes for this user.
StreamID int
}
@ -70,7 +88,7 @@ type DeviceKeys struct {
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
return DeviceMessage{
DeviceKeys: *k,
DeviceKeys: k,
StreamID: streamID,
}
}
@ -128,6 +146,18 @@ type PerformUploadKeysResponse struct {
OneTimeKeyCounts []OneTimeKeysCount
}
// PerformDeleteKeysRequest asks the keyserver to forget about certain
// keys, and signatures related to those keys.
type PerformDeleteKeysRequest struct {
UserID string
KeyIDs []gomatrixserverlib.KeyID
}
// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
type PerformDeleteKeysResponse struct {
Error *KeyError
}
// KeyError sets a key error field on KeyErrors
func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
if r.KeyErrors[userID] == nil {
@ -151,7 +181,30 @@ type PerformClaimKeysResponse struct {
Error *KeyError
}
type PerformUploadDeviceKeysRequest struct {
gomatrixserverlib.CrossSigningKeys
// The user that uploaded the key, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceKeysResponse struct {
Error *KeyError
}
type PerformUploadDeviceSignaturesRequest struct {
Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
// The user that uploaded the sig, should be populated by the clientapi.
UserID string
}
type PerformUploadDeviceSignaturesResponse struct {
Error *KeyError
}
type QueryKeysRequest struct {
// The user ID asking for the keys, e.g. if from a client API request.
// Will not be populated if the key request came from federation.
UserID string
// Maps user IDs to a list of devices
UserToDevices map[string][]string
Timeout time.Duration
@ -162,6 +215,10 @@ type QueryKeysResponse struct {
Failures map[string]interface{}
// Map of user_id to device_id to device_key
DeviceKeys map[string]map[string]json.RawMessage
// Maps of user_id to cross signing key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// Set if there was a fatal error processing this query
Error *KeyError
}
@ -211,6 +268,24 @@ type QueryDeviceMessagesResponse struct {
Error *KeyError
}
type QuerySignaturesRequest struct {
// A map of target user ID -> target key/device IDs to retrieve signatures for
TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
}
type QuerySignaturesResponse struct {
// A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
// A map of target user ID -> cross-signing master key
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing self-signing key
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// A map of target user ID -> cross-signing user-signing key
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
// The request error, if any
Error *KeyError
}
type InputDeviceListUpdateRequest struct {
Event gomatrixserverlib.DeviceListUpdateEvent
}

View file

@ -0,0 +1,112 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"github.com/Shopify/sarama"
)
type OutputCrossSigningKeyUpdateConsumer struct {
eduServerConsumer *internal.ContinualConsumer
keyDB storage.Database
keyAPI api.KeyInternalAPI
serverName string
}
func NewOutputCrossSigningKeyUpdateConsumer(
process *process.ProcessContext,
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
keyDB storage.Database,
keyAPI api.KeyInternalAPI,
) *OutputCrossSigningKeyUpdateConsumer {
// The keyserver both produces and consumes on the TopicOutputKeyChangeEvent
// topic. We will only produce events where the UserID matches our server name,
// and we will only consume events where the UserID does NOT match our server
// name (because the update came from a remote server).
consumer := internal.ContinualConsumer{
Process: process,
ComponentName: "keyserver/keyserver",
Topic: cfg.Global.Kafka.TopicFor(config.TopicOutputKeyChangeEvent),
Consumer: kafkaConsumer,
PartitionStore: keyDB,
}
s := &OutputCrossSigningKeyUpdateConsumer{
eduServerConsumer: &consumer,
keyDB: keyDB,
keyAPI: keyAPI,
serverName: string(cfg.Global.ServerName),
}
consumer.ProcessMessage = s.onMessage
return s
}
func (s *OutputCrossSigningKeyUpdateConsumer) Start() error {
return s.eduServerConsumer.Start()
}
// onMessage is called in response to a message received on the
// key change events topic from the key server.
func (t *OutputCrossSigningKeyUpdateConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var m api.DeviceMessage
if err := json.Unmarshal(msg.Value, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic")
return nil
}
switch m.Type {
case api.TypeCrossSigningUpdate:
return t.onCrossSigningMessage(m)
default:
return nil
}
}
func (s *OutputCrossSigningKeyUpdateConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
output := m.CrossSigningKeyUpdate
_, host, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
logrus.WithError(err).Errorf("eduserver output log: user ID parse failure")
return nil
}
if host == gomatrixserverlib.ServerName(s.serverName) {
// Ignore any messages that contain information about our own users, as
// they already originated from this server.
return nil
}
uploadReq := &api.PerformUploadDeviceKeysRequest{
UserID: output.UserID,
}
if output.MasterKey != nil {
uploadReq.MasterKey = *output.MasterKey
}
if output.SelfSigningKey != nil {
uploadReq.SelfSigningKey = *output.SelfSigningKey
}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
s.keyAPI.PerformUploadDeviceKeys(context.TODO(), uploadReq, uploadRes)
return uploadRes.Error
}

View file

@ -0,0 +1,550 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"bytes"
"context"
"crypto/ed25519"
"database/sql"
"fmt"
"strings"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/curve25519"
)
func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error {
// Is there exactly one key?
if len(key.Keys) != 1 {
return fmt.Errorf("should contain exactly one key")
}
// Does the key ID match the key value? Iterates exactly once
for keyID, keyData := range key.Keys {
b64 := keyData.Encode()
tokens := strings.Split(string(keyID), ":")
if len(tokens) != 2 {
return fmt.Errorf("key ID is incorrectly formatted")
}
if tokens[1] != b64 {
return fmt.Errorf("key ID isn't correct")
}
switch tokens[0] {
case "ed25519":
if len(keyData) != ed25519.PublicKeySize {
return fmt.Errorf("ed25519 key is not the correct length")
}
case "curve25519":
if len(keyData) != curve25519.PointSize {
return fmt.Errorf("curve25519 key is not the correct length")
}
default:
// We can't enforce the key length to be correct for an
// algorithm that we don't recognise, so instead we'll
// just make sure that it isn't incredibly excessive.
if l := len(keyData); l > 4096 {
return fmt.Errorf("unknown key type is too long (%d bytes)", l)
}
}
}
// Check to see if the signatures make sense
for _, forOriginUser := range key.Signatures {
for originKeyID, originSignature := range forOriginUser {
switch strings.SplitN(string(originKeyID), ":", 1)[0] {
case "ed25519":
if len(originSignature) != ed25519.SignatureSize {
return fmt.Errorf("ed25519 signature is not the correct length")
}
case "curve25519":
return fmt.Errorf("curve25519 signatures are impossible")
default:
if l := len(originSignature); l > 4096 {
return fmt.Errorf("unknown signature type is too long (%d bytes)", l)
}
}
}
}
// Does the key claim to be from the right user?
if userID != key.UserID {
return fmt.Errorf("key has a user ID mismatch")
}
// Does the key contain the correct purpose?
useful := false
for _, usage := range key.Usage {
if usage == purpose {
useful = true
break
}
}
if !useful {
return fmt.Errorf("key does not contain correct usage purpose")
}
return nil
}
// nolint:gocyclo
func (a *KeyInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
// Find the keys to store.
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
toStore := types.CrossSigningKeyMap{}
hasMasterKey := false
if len(req.MasterKey.Keys) > 0 {
if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
res.Error = &api.KeyError{
Err: "Master key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key
}
hasMasterKey = true
}
if len(req.SelfSigningKey.Keys) > 0 {
if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil {
res.Error = &api.KeyError{
Err: "Self-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key
}
}
if len(req.UserSigningKey.Keys) > 0 {
if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil {
res.Error = &api.KeyError{
Err: "User-signing key sanity check failed: " + err.Error(),
IsInvalidParam: true,
}
return
}
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key
}
}
// If there's nothing to do then stop here.
if len(toStore) == 0 {
res.Error = &api.KeyError{
Err: "No keys were supplied in the request",
IsMissingParam: true,
}
return
}
// We can't have a self-signing or user-signing key without a master
// key, so make sure we have one of those.
if !hasMasterKey {
existingKeys, err := a.DB.CrossSigningKeysDataForUser(ctx, req.UserID)
if err != nil {
res.Error = &api.KeyError{
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
}
return
}
_, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]
}
// If we still can't find a master key for the user then stop the upload.
// This satisfies the "Fails to upload self-signing key without master key" test.
if !hasMasterKey {
res.Error = &api.KeyError{
Err: "No master key was found",
IsMissingParam: true,
}
return
}
// Store the keys.
if err := a.DB.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
}
return
}
// Now upload any signatures that were included with the keys.
for _, key := range byPurpose {
var targetKeyID gomatrixserverlib.KeyID
for targetKey := range key.Keys { // iterates once, see sanityCheckKey
targetKeyID = targetKey
}
for sigUserID, forSigUserID := range key.Signatures {
if sigUserID != req.UserID {
continue
}
for sigKeyID, sigBytes := range forSigUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
}
return
}
}
}
}
// Finally, generate a notification that we updated the keys.
if _, host, err := gomatrixserverlib.SplitID('@', req.UserID); err == nil && host == a.ThisServer {
update := eduserverAPI.CrossSigningKeyUpdate{
UserID: req.UserID,
}
if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok {
update.MasterKey = &mk
}
if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok {
update.SelfSigningKey = &ssk
}
if update.MasterKey == nil && update.SelfSigningKey == nil {
return
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
}
}
}
func (a *KeyInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) {
// Before we do anything, we need the master and self-signing keys for this user.
// Then we can verify the signatures make sense.
queryReq := &api.QueryKeysRequest{
UserID: req.UserID,
UserToDevices: map[string][]string{},
}
queryRes := &api.QueryKeysResponse{}
for userID := range req.Signatures {
queryReq.UserToDevices[userID] = []string{}
}
a.QueryKeys(ctx, queryReq, queryRes)
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
// Sort signatures into two groups: one where people have signed their own
// keys and one where people have signed someone elses
for userID, forUserID := range req.Signatures {
for keyID, keyOrDevice := range forUserID {
switch key := keyOrDevice.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey:
if key.UserID == req.UserID {
if _, ok := selfSignatures[userID]; !ok {
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
}
selfSignatures[userID][keyID] = keyOrDevice
} else {
if _, ok := otherSignatures[userID]; !ok {
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
}
otherSignatures[userID][keyID] = keyOrDevice
}
case *gomatrixserverlib.DeviceKeys:
if key.UserID == req.UserID {
if _, ok := selfSignatures[userID]; !ok {
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
}
selfSignatures[userID][keyID] = keyOrDevice
} else {
if _, ok := otherSignatures[userID]; !ok {
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
}
otherSignatures[userID][keyID] = keyOrDevice
}
default:
continue
}
}
}
if err := a.processSelfSignatures(ctx, selfSignatures); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
}
return
}
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
}
return
}
// Finally, generate a notification that we updated the signatures.
for userID := range req.Signatures {
if _, host, err := gomatrixserverlib.SplitID('@', userID); err == nil && host == a.ThisServer {
update := eduserverAPI.CrossSigningKeyUpdate{
UserID: userID,
}
if err := a.Producer.ProduceSigningKeyUpdate(update); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
}
return
}
}
}
}
func (a *KeyInternalAPI) processSelfSignatures(
ctx context.Context,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
// Here we will process:
// * The user signing their own devices using their self-signing key
// * The user signing their master key using one of their devices
for targetUserID, forTargetUserID := range signatures {
for targetKeyID, signature := range forTargetUserID {
switch sig := signature.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey:
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
}
}
}
case *gomatrixserverlib.DeviceKeys:
for originUserID, forOriginUserID := range sig.Signatures {
for originKeyID, originSig := range forOriginUserID {
if err := a.DB.StoreCrossSigningSigsForTarget(
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
}
}
}
default:
return fmt.Errorf("unexpected type assertion")
}
}
}
return nil
}
func (a *KeyInternalAPI) processOtherSignatures(
ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
) error {
// Here we will process:
// * A user signing someone else's master keys using their user-signing keys
for targetUserID, forTargetUserID := range signatures {
for _, signature := range forTargetUserID {
switch sig := signature.CrossSigningBody.(type) {
case *gomatrixserverlib.CrossSigningKey:
// Find the local copy of the master key. We'll use this to be
// sure that the supplied stanza matches the key that we think it
// should be.
masterKey, ok := queryRes.MasterKeys[targetUserID]
if !ok {
return fmt.Errorf("failed to find master key for user %q", targetUserID)
}
// For each key ID, write the signatures. Maybe there'll be more
// than one algorithm in the future so it's best not to focus on
// everything being ed25519:.
for targetKeyID, suppliedKeyData := range sig.Keys {
// The master key will be supplied in the request, but we should
// make sure that it matches what we think the master key should
// actually be.
localKeyData, lok := masterKey.Keys[targetKeyID]
if !lok {
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
} else if !bytes.Equal(suppliedKeyData, localKeyData) {
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
}
// We only care about the signatures from the uploading user, so
// we will ignore anything that didn't originate from them.
userSigs, ok := sig.Signatures[userID]
if !ok {
return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID)
}
for originKeyID, originSig := range userSigs {
if err := a.DB.StoreCrossSigningSigsForTarget(
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
); err != nil {
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
}
}
}
default:
// Users should only be signing another person's master key,
// so if we're here, it's probably because it's actually a
// gomatrixserverlib.DeviceKeys, which doesn't make sense.
}
}
}
return nil
}
func (a *KeyInternalAPI) crossSigningKeysFromDatabase(
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
) {
for userID := range req.UserToDevices {
keys, err := a.DB.CrossSigningKeysForUser(ctx, userID)
if err != nil {
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", userID)
continue
}
for keyType, key := range keys {
var keyID gomatrixserverlib.KeyID
for id := range key.Keys {
keyID = id
break
}
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, keyID)
if err != nil && err != sql.ErrNoRows {
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", userID, keyID)
continue
}
appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) {
if key.Signatures == nil {
key.Signatures = types.CrossSigningSigMap{}
}
if _, ok := key.Signatures[originUserID]; !ok {
key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes)
}
key.Signatures[originUserID][originKeyID] = signature
}
for originUserID, forOrigin := range sigMap {
for originKeyID, signature := range forOrigin {
switch {
case req.UserID != "" && originUserID == req.UserID:
// Include signatures that we created
appendSignature(originUserID, originKeyID, signature)
case originUserID == userID:
// Include signatures that were created by the person whose key
// we are processing
appendSignature(originUserID, originKeyID, signature)
}
}
}
switch keyType {
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
res.MasterKeys[userID] = key
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
res.SelfSigningKeys[userID] = key
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
res.UserSigningKeys[userID] = key
}
}
}
}
func (a *KeyInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) {
for targetUserID, forTargetUser := range req.TargetIDs {
keyMap, err := a.DB.CrossSigningKeysForUser(ctx, targetUserID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
}
continue
}
for targetPurpose, targetKey := range keyMap {
switch targetPurpose {
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
if res.MasterKeys == nil {
res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{}
}
res.MasterKeys[targetUserID] = targetKey
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
if res.SelfSigningKeys == nil {
res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
}
res.SelfSigningKeys[targetUserID] = targetKey
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
if res.UserSigningKeys == nil {
res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
}
res.UserSigningKeys[targetUserID] = targetKey
}
}
for _, targetKeyID := range forTargetUser {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, targetUserID, targetKeyID)
if err != nil && err != sql.ErrNoRows {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
}
return
}
for sourceUserID, forSourceUser := range sigMap {
for sourceKeyID, sourceSig := range forSourceUser {
if res.Signatures == nil {
res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
}
if _, ok := res.Signatures[targetUserID]; !ok {
res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
}
if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok {
res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{}
}
if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok {
res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig
}
}
}
}
}

View file

@ -82,6 +82,7 @@ type DeviceListUpdater struct {
mu *sync.Mutex // protects UserIDToMutex
db DeviceListUpdaterDatabase
api DeviceListUpdaterAPI
producer KeyChangeProducer
fedClient fedsenderapi.FederationClient
workerChans []chan gomatrixserverlib.ServerName
@ -114,6 +115,10 @@ type DeviceListUpdaterDatabase interface {
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
}
type DeviceListUpdaterAPI interface {
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse)
}
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
type KeyChangeProducer interface {
ProduceKeyChanges(keys []api.DeviceMessage) error
@ -121,13 +126,14 @@ type KeyChangeProducer interface {
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
func NewDeviceListUpdater(
db DeviceListUpdaterDatabase, producer KeyChangeProducer, fedClient fedsenderapi.FederationClient,
numWorkers int,
db DeviceListUpdaterDatabase, api DeviceListUpdaterAPI, producer KeyChangeProducer,
fedClient fedsenderapi.FederationClient, numWorkers int,
) *DeviceListUpdater {
return &DeviceListUpdater{
userIDToMutex: make(map[string]*sync.Mutex),
mu: &sync.Mutex{},
db: db,
api: api,
producer: producer,
fedClient: fedClient,
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
@ -225,7 +231,8 @@ func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.
}
keys := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
KeyJSON: k,
@ -367,6 +374,23 @@ func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerNam
}
continue
}
if res.MasterKey != nil || res.SelfSigningKey != nil {
uploadReq := &api.PerformUploadDeviceKeysRequest{
UserID: userID,
}
uploadRes := &api.PerformUploadDeviceKeysResponse{}
if res.MasterKey != nil {
if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
uploadReq.MasterKey = *res.MasterKey
}
}
if res.SelfSigningKey != nil {
if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
uploadReq.SelfSigningKey = *res.SelfSigningKey
}
}
u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
}
err = u.updateDeviceList(&res)
if err != nil {
logger.WithError(err).WithField("user_id", userID).Error("fetched device list but failed to store/emit it")
@ -394,8 +418,9 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
continue
}
keys[i] = api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: res.StreamID,
DeviceKeys: api.DeviceKeys{
DeviceKeys: &api.DeviceKeys{
DeviceID: device.DeviceID,
DisplayName: device.DisplayName,
UserID: res.UserID,
@ -403,7 +428,8 @@ func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevi
},
}
existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: res.UserID,
DeviceID: device.DeviceID,
},

View file

@ -95,6 +95,13 @@ func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys
return nil
}
type mockDeviceListUpdaterAPI struct {
}
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) {
}
type roundTripper struct {
fn func(*http.Request) (*http.Response, error)
}
@ -122,8 +129,9 @@ func TestUpdateHavePrevID(t *testing.T) {
return true
},
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
updater := NewDeviceListUpdater(db, producer, nil, 1)
updater := NewDeviceListUpdater(db, ap, producer, nil, 1)
event := gomatrixserverlib.DeviceListUpdateEvent{
DeviceDisplayName: "Foo Bar",
Deleted: false,
@ -138,8 +146,9 @@ func TestUpdateHavePrevID(t *testing.T) {
t.Fatalf("Update returned an error: %s", err)
}
want := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: event.StreamID,
DeviceKeys: api.DeviceKeys{
DeviceKeys: &api.DeviceKeys{
DeviceID: event.DeviceID,
DisplayName: event.DeviceDisplayName,
KeyJSON: event.Keys,
@ -166,6 +175,7 @@ func TestUpdateNoPrevID(t *testing.T) {
return false
},
}
ap := &mockDeviceListUpdaterAPI{}
producer := &mockKeyChangeProducer{}
remoteUserID := "@alice:example.somewhere"
var wg sync.WaitGroup
@ -193,7 +203,7 @@ func TestUpdateNoPrevID(t *testing.T) {
`)),
}, nil
})
updater := NewDeviceListUpdater(db, producer, fedClient, 2)
updater := NewDeviceListUpdater(db, ap, producer, fedClient, 2)
if err := updater.Start(); err != nil {
t.Fatalf("failed to start updater: %s", err)
}
@ -215,8 +225,9 @@ func TestUpdateNoPrevID(t *testing.T) {
// wait a bit for db to be updated...
time.Sleep(100 * time.Millisecond)
want := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
StreamID: 5,
DeviceKeys: api.DeviceKeys{
DeviceKeys: &api.DeviceKeys{
DeviceID: "JLAFKJWSCS",
DisplayName: "Mobile Phone",
UserID: remoteUserID,

View file

@ -182,6 +182,14 @@ func (a *KeyInternalAPI) claimRemoteKeys(
util.GetLogger(ctx).WithField("num_keys", keysClaimed).Info("Claimed remote keys")
}
func (a *KeyInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) {
if err := a.DB.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
}
}
}
func (a *KeyInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) {
count, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
@ -221,9 +229,17 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) {
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
res.Failures = make(map[string]interface{})
// get cross-signing keys from the database
a.crossSigningKeysFromDatabase(ctx, req, res)
// make a map from domain to device keys
domainToDeviceKeys := make(map[string]map[string][]string)
domainToCrossSigningKeys := make(map[string]map[string]struct{})
for userID, deviceIDs := range req.UserToDevices {
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
@ -274,16 +290,56 @@ func (a *KeyInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysReques
domainToDeviceKeys[domain] = make(map[string][]string)
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
}
// work out if our cross-signing request for this user was
// satisfied, if not add them to the list of things to fetch
if _, ok := res.MasterKeys[userID]; !ok {
if _, ok := domainToCrossSigningKeys[domain]; !ok {
domainToCrossSigningKeys[domain] = make(map[string]struct{})
}
domainToCrossSigningKeys[domain][userID] = struct{}{}
}
if _, ok := res.SelfSigningKeys[userID]; !ok {
if _, ok := domainToCrossSigningKeys[domain]; !ok {
domainToCrossSigningKeys[domain] = make(map[string]struct{})
}
domainToCrossSigningKeys[domain][userID] = struct{}{}
}
}
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, domainToDeviceKeys)
if len(domainToDeviceKeys) == 0 {
return // nothing to query
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
}
// perform key queries for remote devices
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
// Finally, append signatures that we know about
// TODO: This is horrible because we need to round-trip the signature from
// JSON, add the signatures and marshal it again, for some reason?
for userID, forUserID := range res.DeviceKeys {
for keyID, key := range forUserID {
sigMap, err := a.DB.CrossSigningSigsForTarget(ctx, userID, gomatrixserverlib.KeyID(keyID))
if err != nil {
logrus.WithError(err).Errorf("a.DB.CrossSigningSigsForTarget failed")
continue
}
if len(sigMap) == 0 {
continue
}
var deviceKey gomatrixserverlib.DeviceKeys
if err = json.Unmarshal(key, &deviceKey); err != nil {
continue
}
for sourceUserID, forSourceUser := range sigMap {
for sourceKeyID, sourceSig := range forSourceUser {
deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
}
}
if js, err := json.Marshal(deviceKey); err == nil {
res.DeviceKeys[userID][keyID] = js
}
}
}
}
func (a *KeyInternalAPI) remoteKeysFromDatabase(
@ -313,18 +369,36 @@ func (a *KeyInternalAPI) remoteKeysFromDatabase(
}
func (a *KeyInternalAPI) queryRemoteKeys(
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse, domainToDeviceKeys map[string]map[string][]string,
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
) {
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
// allows us to wait until all federation servers have been poked
var wg sync.WaitGroup
wg.Add(len(domainToDeviceKeys))
// mutex for writing directly to res (e.g failures)
var respMu sync.Mutex
domains := map[string]struct{}{}
for domain := range domainToDeviceKeys {
if domain == string(a.ThisServer) {
continue
}
domains[domain] = struct{}{}
}
for domain := range domainToCrossSigningKeys {
if domain == string(a.ThisServer) {
continue
}
domains[domain] = struct{}{}
}
wg.Add(len(domains))
// fan out
for domain, deviceKeys := range domainToDeviceKeys {
go a.queryRemoteKeysOnServer(ctx, domain, deviceKeys, &wg, &respMu, timeout, resultCh, res)
for domain := range domains {
go a.queryRemoteKeysOnServer(
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
&wg, &respMu, timeout, resultCh, res,
)
}
// Close the result channel when the goroutines have quit so the for .. range exits
@ -344,28 +418,53 @@ func (a *KeyInternalAPI) queryRemoteKeys(
res.DeviceKeys[userID][deviceID] = keyJSON
}
}
for userID, body := range result.MasterKeys {
res.MasterKeys[userID] = body
}
for userID, body := range result.SelfSigningKeys {
res.SelfSigningKeys[userID] = body
}
// TODO: do we want to persist these somewhere now
// that we have fetched them?
}
}
func (a *KeyInternalAPI) queryRemoteKeysOnServer(
ctx context.Context, serverName string, devKeys map[string][]string, wg *sync.WaitGroup,
respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
res *api.QueryKeysResponse,
) {
defer wg.Done()
fedCtx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
fedCtx := ctx
if timeout > 0 {
var cancel context.CancelFunc
fedCtx, cancel = context.WithTimeout(ctx, timeout)
defer cancel()
}
// for users who we do not have any knowledge about, try to start doing device list updates for them
// by hitting /users/devices - otherwise fallback to /keys/query which has nicer bulk properties but
// lack a stream ID.
var userIDsForAllDevices []string
userIDsForAllDevices := map[string]struct{}{}
for userID, deviceIDs := range devKeys {
if len(deviceIDs) == 0 {
userIDsForAllDevices = append(userIDsForAllDevices, userID)
userIDsForAllDevices[userID] = struct{}{}
delete(devKeys, userID)
}
}
for _, userID := range userIDsForAllDevices {
// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
// a device list update, so we'll populate those back into the /keys/query list if not
for userID := range crossSigningKeys {
if devKeys == nil {
devKeys = map[string][]string{}
}
if _, ok := userIDsForAllDevices[userID]; !ok {
devKeys[userID] = []string{}
}
}
for userID := range userIDsForAllDevices {
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
if err != nil {
logrus.WithFields(logrus.Fields{
@ -482,7 +581,8 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore {
existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID,
},

View file

@ -27,13 +27,17 @@ import (
// HTTP paths for the internal HTTP APIs
const (
InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate"
PerformUploadKeysPath = "/keyserver/performUploadKeys"
PerformClaimKeysPath = "/keyserver/performClaimKeys"
QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
InputDeviceListUpdatePath = "/keyserver/inputDeviceListUpdate"
PerformUploadKeysPath = "/keyserver/performUploadKeys"
PerformClaimKeysPath = "/keyserver/performClaimKeys"
PerformDeleteKeysPath = "/keyserver/performDeleteKeys"
PerformUploadDeviceKeysPath = "/keyserver/performUploadDeviceKeys"
PerformUploadDeviceSignaturesPath = "/keyserver/performUploadDeviceSignatures"
QueryKeysPath = "/keyserver/queryKeys"
QueryKeyChangesPath = "/keyserver/queryKeyChanges"
QueryOneTimeKeysPath = "/keyserver/queryOneTimeKeys"
QueryDeviceMessagesPath = "/keyserver/queryDeviceMessages"
QuerySignaturesPath = "/keyserver/querySignatures"
)
// NewKeyServerClient creates a KeyInternalAPI implemented by talking to a HTTP POST API.
@ -91,6 +95,23 @@ func (h *httpKeyInternalAPI) PerformClaimKeys(
}
}
func (h *httpKeyInternalAPI) PerformDeleteKeys(
ctx context.Context,
request *api.PerformDeleteKeysRequest,
response *api.PerformDeleteKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformClaimKeys")
defer span.Finish()
apiURL := h.apiURL + PerformClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}
func (h *httpKeyInternalAPI) PerformUploadKeys(
ctx context.Context,
request *api.PerformUploadKeysRequest,
@ -175,3 +196,54 @@ func (h *httpKeyInternalAPI) QueryKeyChanges(
}
}
}
func (h *httpKeyInternalAPI) PerformUploadDeviceKeys(
ctx context.Context,
request *api.PerformUploadDeviceKeysRequest,
response *api.PerformUploadDeviceKeysResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceKeys")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}
func (h *httpKeyInternalAPI) PerformUploadDeviceSignatures(
ctx context.Context,
request *api.PerformUploadDeviceSignaturesRequest,
response *api.PerformUploadDeviceSignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformUploadDeviceSignatures")
defer span.Finish()
apiURL := h.apiURL + PerformUploadDeviceSignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}
func (h *httpKeyInternalAPI) QuerySignatures(
ctx context.Context,
request *api.QuerySignaturesRequest,
response *api.QuerySignaturesResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QuerySignatures")
defer span.Finish()
apiURL := h.apiURL + QuerySignaturesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.Error = &api.KeyError{
Err: err.Error(),
}
}
}

View file

@ -47,6 +47,17 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformDeleteKeysPath,
httputil.MakeInternalAPI("performDeleteKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformDeleteKeysRequest{}
response := api.PerformDeleteKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformDeleteKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUploadKeysPath,
httputil.MakeInternalAPI("performUploadKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadKeysRequest{}
@ -58,6 +69,28 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUploadDeviceKeysPath,
httputil.MakeInternalAPI("performUploadDeviceKeys", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceKeysRequest{}
response := api.PerformUploadDeviceKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceKeys(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(PerformUploadDeviceSignaturesPath,
httputil.MakeInternalAPI("performUploadDeviceSignatures", func(req *http.Request) util.JSONResponse {
request := api.PerformUploadDeviceSignaturesRequest{}
response := api.PerformUploadDeviceSignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.PerformUploadDeviceSignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QueryKeysPath,
httputil.MakeInternalAPI("queryKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryKeysRequest{}
@ -102,4 +135,15 @@ func AddRoutes(internalAPIMux *mux.Router, s api.KeyInternalAPI) {
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(QuerySignaturesPath,
httputil.MakeInternalAPI("querySignatures", func(req *http.Request) util.JSONResponse {
request := api.QuerySignaturesRequest{}
response := api.QuerySignaturesResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
s.QuerySignatures(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -18,10 +18,12 @@ import (
"github.com/gorilla/mux"
fedsenderapi "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/consumers"
"github.com/matrix-org/dendrite/keyserver/internal"
"github.com/matrix-org/dendrite/keyserver/inthttp"
"github.com/matrix-org/dendrite/keyserver/producers"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/jetstream"
"github.com/sirupsen/logrus"
@ -36,9 +38,9 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
base *setup.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
) api.KeyInternalAPI {
_, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream)
consumer, producer := jetstream.SetupConsumerProducer(&cfg.Matrix.JetStream)
db, err := storage.NewDatabase(&cfg.Database)
if err != nil {
@ -49,17 +51,26 @@ func NewInternalAPI(
Producer: producer,
DB: db,
}
updater := internal.NewDeviceListUpdater(db, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap := &internal.KeyInternalAPI{
DB: db,
ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient,
Producer: keyChangeProducer,
}
updater := internal.NewDeviceListUpdater(db, ap, keyChangeProducer, fedClient, 8) // 8 workers TODO: configurable
ap.Updater = updater
go func() {
if err := updater.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start device list updater")
}
}()
return &internal.KeyInternalAPI{
DB: db,
ThisServer: cfg.Matrix.ServerName,
FedClient: fedClient,
Producer: keyChangeProducer,
Updater: updater,
keyconsumer := consumers.NewOutputCrossSigningKeyUpdateConsumer(
base.ProcessContext, base.Cfg, consumer, db, ap,
)
if err := keyconsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start keyserver EDU server consumer")
}
return ap
}

View file

@ -19,6 +19,7 @@ import (
"encoding/json"
"github.com/Shopify/sarama"
eduapi "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/sirupsen/logrus"
@ -73,3 +74,36 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
}
return nil
}
func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error {
var m sarama.ProducerMessage
output := &api.DeviceMessage{
Type: api.TypeCrossSigningUpdate,
OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{
CrossSigningKeyUpdate: key,
},
}
value, err := json.Marshal(output)
if err != nil {
return err
}
m.Topic = string(p.Topic)
m.Key = sarama.StringEncoder(key.UserID)
m.Value = sarama.ByteEncoder(value)
partition, offset, err := p.Producer.SendMessage(&m)
if err != nil {
return err
}
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
if err != nil {
return err
}
logrus.WithFields(logrus.Fields{
"user_id": key.UserID,
}).Infof("Produced to cross-signing update topic '%s'", p.Topic)
return nil
}

View file

@ -18,11 +18,15 @@ import (
"context"
"encoding/json"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
internal.PartitionStorer
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
@ -54,6 +58,10 @@ type Database interface {
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
@ -73,4 +81,11 @@ type Database interface {
// MarkDeviceListStale sets the stale bit for this user to isStale.
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
}

View file

@ -0,0 +1,102 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningKeysSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
user_id TEXT NOT NULL,
key_type SMALLINT NOT NULL,
key_data TEXT NOT NULL,
PRIMARY KEY (user_id, key_type)
);
`
const selectCrossSigningKeysForUserSQL = "" +
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
" WHERE user_id = $1"
const upsertCrossSigningKeysForUserSQL = "" +
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
" VALUES($1, $2, $3)" +
" ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
type crossSigningKeysStatements struct {
db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt
upsertCrossSigningKeysForUserStmt *sql.Stmt
}
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
s := &crossSigningKeysStatements{
db: db,
}
_, err := db.Exec(crossSigningKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
}.Prepare(db)
}
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string,
) (r types.CrossSigningKeyMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
r = types.CrossSigningKeyMap{}
for rows.Next() {
var keyTypeInt int16
var keyData gomatrixserverlib.Base64Bytes
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
return nil, err
}
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
if !ok {
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
}
r[keyType] = keyData
}
return
}
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
) error {
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)
}
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,118 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningSigsSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
);
`
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
" VALUES($1, $2, $3, $4, $5)" +
" ON CONFLICT (origin_user_id, target_user_id, target_key_id) DO UPDATE SET (origin_key_id, signature) = ($2, $5)"
const deleteCrossSigningSigsForTargetSQL = "" +
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
type crossSigningSigsStatements struct {
db *sql.DB
selectCrossSigningSigsForTargetStmt *sql.Stmt
upsertCrossSigningSigsForTargetStmt *sql.Stmt
deleteCrossSigningSigsForTargetStmt *sql.Stmt
}
func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
s := &crossSigningSigsStatements{
db: db,
}
_, err := db.Exec(crossSigningSigsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
}.Prepare(db)
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{}
for rows.Next() {
var userID string
var keyID gomatrixserverlib.KeyID
var signature gomatrixserverlib.Base64Bytes
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
return nil, err
}
if _, ok := r[userID]; !ok {
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
r[userID][keyID] = signature
}
return
}
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}

View file

@ -62,6 +62,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
const deleteDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
@ -72,6 +75,7 @@ type deviceKeysStatements struct {
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
@ -98,6 +102,9 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.countStreamIDsForUserStmt, err = db.Prepare(countStreamIDsForUserSQL); err != nil {
return nil, err
}
if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
@ -114,6 +121,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return err
}
// this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {
@ -162,6 +170,11 @@ func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx
return nil
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
@ -179,7 +192,10 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceMessage
dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{},
}
dk.UserID = userID
var keyJSON string
var streamID int

View file

@ -43,12 +43,26 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
return &shared.Database{
csk, err := NewPostgresCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewPostgresCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
d := &shared.Database{
DB: db,
Writer: sqlutil.NewDummyWriter(),
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
}
if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil {
return nil, err
}
return d, nil
}

View file

@ -18,10 +18,12 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -32,6 +34,9 @@ type Database struct {
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
CrossSigningKeysTable tables.CrossSigningKeys
CrossSigningSigsTable tables.CrossSigningSigs
sqlutil.PartitionOffsetStatements
}
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -152,3 +157,95 @@ func (d *Database) MarkDeviceListStale(ctx context.Context, userID string, isSta
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
})
}
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *Database) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
for _, deviceID := range deviceIDs {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
}
}
return nil
})
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *Database) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
if err != nil {
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
}
results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
for purpose, key := range keyMap {
keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
result := gomatrixserverlib.CrossSigningKey{
UserID: userID,
Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
keyID: key,
},
}
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, keyID)
if err != nil {
continue
}
for sigUserID, forSigUserID := range sigMap {
if userID != sigUserID {
continue
}
if result.Signatures == nil {
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
if _, ok := result.Signatures[sigUserID]; !ok {
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for sigKeyID, sigBytes := range forSigUserID {
result.Signatures[sigUserID][sigKeyID] = sigBytes
}
}
results[purpose] = result
}
return results, nil
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *Database) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *Database) CrossSigningSigsForTarget(ctx context.Context, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, targetUserID, targetKeyID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *Database) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
return nil
})
}
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
func (d *Database) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
}
return nil
})
}

View file

@ -0,0 +1,101 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningKeysSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
user_id TEXT NOT NULL,
key_type INTEGER NOT NULL,
key_data TEXT NOT NULL,
PRIMARY KEY (user_id, key_type)
);
`
const selectCrossSigningKeysForUserSQL = "" +
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
" WHERE user_id = $1"
const upsertCrossSigningKeysForUserSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
" VALUES($1, $2, $3)"
type crossSigningKeysStatements struct {
db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt
upsertCrossSigningKeysForUserStmt *sql.Stmt
}
func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
s := &crossSigningKeysStatements{
db: db,
}
_, err := db.Exec(crossSigningKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
}.Prepare(db)
}
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string,
) (r types.CrossSigningKeyMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
r = types.CrossSigningKeyMap{}
for rows.Next() {
var keyTypeInt int16
var keyData gomatrixserverlib.Base64Bytes
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
return nil, err
}
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
if !ok {
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
}
r[keyType] = keyData
}
return
}
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
) error {
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)
}
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,117 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningSigsSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
);
`
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE target_user_id = $1 AND target_key_id = $2"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
" VALUES($1, $2, $3, $4, $5)"
const deleteCrossSigningSigsForTargetSQL = "" +
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
type crossSigningSigsStatements struct {
db *sql.DB
selectCrossSigningSigsForTargetStmt *sql.Stmt
upsertCrossSigningSigsForTargetStmt *sql.Stmt
deleteCrossSigningSigsForTargetStmt *sql.Stmt
}
func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
s := &crossSigningSigsStatements{
db: db,
}
_, err := db.Exec(crossSigningSigsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
}.Prepare(db)
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, targetUserID, targetKeyID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{}
for rows.Next() {
var userID string
var keyID gomatrixserverlib.KeyID
var signature gomatrixserverlib.Base64Bytes
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
return nil, err
}
if _, ok := r[userID]; !ok {
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
r[userID][keyID] = signature
}
return
}
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}

View file

@ -58,6 +58,9 @@ const selectMaxStreamForUserSQL = "" +
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const deleteDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
@ -67,6 +70,7 @@ type deviceKeysStatements struct {
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
@ -90,12 +94,20 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
if s.deleteDeviceKeysStmt, err = db.Prepare(deleteDeviceKeysSQL); err != nil {
return nil, err
}
if s.deleteAllDeviceKeysStmt, err = db.Prepare(deleteAllDeviceKeysSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
@ -113,7 +125,11 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
for rows.Next() {
var dk api.DeviceMessage
dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{},
}
dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID
var keyJSON string
var streamID int
@ -144,6 +160,7 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return err
}
// this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {

View file

@ -41,12 +41,26 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
if err != nil {
return nil, err
}
return &shared.Database{
csk, err := NewSqliteCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewSqliteCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
d := &shared.Database{
DB: db,
Writer: sqlutil.NewExclusiveWriter(),
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
}, nil
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
}
if err = d.PartitionOffsetStatements.Prepare(db, d.Writer, "keyserver"); err != nil {
return nil, err
}
return d, nil
}

View file

@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage

View file

@ -105,7 +105,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
@ -113,7 +114,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// StreamID: 1
},
{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
@ -121,7 +123,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// StreamID: 1 as this is a different user
},
{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
@ -143,7 +146,8 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),

View file

@ -20,6 +20,7 @@ import (
"encoding/json"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -38,6 +39,7 @@ type DeviceKeys interface {
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
@ -52,3 +54,14 @@ type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
}
type CrossSigningKeys interface {
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
}
type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
}

View file

@ -0,0 +1,39 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import "github.com/matrix-org/gomatrixserverlib"
// KeyTypePurposeToInt maps a purpose to an integer, which is used in the
// database to reduce the amount of space taken up by this column.
var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{
gomatrixserverlib.CrossSigningKeyPurposeMaster: 1,
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2,
gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3,
}
// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the
// database to reduce the amount of space taken up by this column.
var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{
1: gomatrixserverlib.CrossSigningKeyPurposeMaster,
2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
}
// Map of purpose -> public key
type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes
// Map of user ID -> key ID -> signature
type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes