Mirror of https://github.com/hoernschen/dendrite.git (synced 2025-07-31 13:22:46 +00:00)
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
parent bd6f0c14e5
commit 4594233f89
107 changed files with 1730 additions and 1863 deletions
@ -15,9 +15,13 @@
package api

import (
    "bytes"
    "context"
    "encoding/json"
    "strings"
    "time"

    "github.com/matrix-org/dendrite/userapi/types"
    "github.com/matrix-org/gomatrixserverlib"

    "github.com/matrix-org/dendrite/clientapi/auth/authtypes"

@ -26,15 +30,12 @@ import (

// UserInternalAPI is the internal API for information about users and devices.
type UserInternalAPI interface {
    AppserviceUserAPI
    SyncUserAPI
    ClientUserAPI
    MediaUserAPI
    FederationUserAPI
    RoomserverUserAPI
    KeyserverUserAPI

    QuerySearchProfilesAPI // used by p2p demos
    QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)
}

// api functions required by the appservice api

@ -43,11 +44,6 @@ type AppserviceUserAPI interface {
    PerformDeviceCreation(ctx context.Context, req *PerformDeviceCreationRequest, res *PerformDeviceCreationResponse) error
}

type KeyserverUserAPI interface {
    QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
    QueryDeviceInfos(ctx context.Context, req *QueryDeviceInfosRequest, res *QueryDeviceInfosResponse) error
}

type RoomserverUserAPI interface {
    QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
    QueryAccountByLocalpart(ctx context.Context, req *QueryAccountByLocalpartRequest, res *QueryAccountByLocalpartResponse) (err error)

@ -60,13 +56,20 @@ type MediaUserAPI interface {

// api functions required by the federation api
type FederationUserAPI interface {
    UploadDeviceKeysAPI
    QueryOpenIDToken(ctx context.Context, req *QueryOpenIDTokenRequest, res *QueryOpenIDTokenResponse) error
    QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
    QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
    QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
    QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
    QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
    PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
}

// api functions required by the sync api
type SyncUserAPI interface {
    QueryAcccessTokenAPI
    SyncKeyAPI
    QueryAccountData(ctx context.Context, req *QueryAccountDataRequest, res *QueryAccountDataResponse) error
    PerformLastSeenUpdate(ctx context.Context, req *PerformLastSeenUpdateRequest, res *PerformLastSeenUpdateResponse) error
    PerformDeviceUpdate(ctx context.Context, req *PerformDeviceUpdateRequest, res *PerformDeviceUpdateResponse) error

@ -79,6 +82,7 @@ type ClientUserAPI interface {
    QueryAcccessTokenAPI
    LoginTokenInternalAPI
    UserLoginAPI
    ClientKeyAPI
    QueryNumericLocalpart(ctx context.Context, req *QueryNumericLocalpartRequest, res *QueryNumericLocalpartResponse) error
    QueryDevices(ctx context.Context, req *QueryDevicesRequest, res *QueryDevicesResponse) error
    QueryProfile(ctx context.Context, req *QueryProfileRequest, res *QueryProfileResponse) error
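The point of keeping these small per-component interfaces while merging the implementations is that every consumer still depends only on the slice it needs. A minimal sketch of that wiring, not part of the commit; the variable names below are hypothetical:

// import "github.com/matrix-org/dendrite/userapi/api"

var userAPI api.UserInternalAPI // built by the userapi component at startup

// Because UserInternalAPI embeds the per-component interfaces, the same value
// satisfies each narrow dependency without any adapters:
var forFederation api.FederationUserAPI = userAPI
var forSync api.SyncUserAPI = userAPI
var forClient api.ClientUserAPI = userAPI
var forKeyserver api.KeyserverUserAPI = userAPI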
@ -681,3 +685,310 @@ type QueryAccountByLocalpartRequest struct {
|
|||
type QueryAccountByLocalpartResponse struct {
|
||||
Account *Account
|
||||
}
|
||||
|
||||
// API functions required by the clientapi
|
||||
type ClientKeyAPI interface {
|
||||
UploadDeviceKeysAPI
|
||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
|
||||
PerformUploadKeys(ctx context.Context, req *PerformUploadKeysRequest, res *PerformUploadKeysResponse) error
|
||||
|
||||
PerformUploadDeviceSignatures(ctx context.Context, req *PerformUploadDeviceSignaturesRequest, res *PerformUploadDeviceSignaturesResponse) error
|
||||
// PerformClaimKeys claims one-time keys for use in pre-key messages
|
||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
|
||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
type UploadDeviceKeysAPI interface {
|
||||
PerformUploadDeviceKeys(ctx context.Context, req *PerformUploadDeviceKeysRequest, res *PerformUploadDeviceKeysResponse) error
|
||||
}
|
||||
|
||||
// API functions required by the syncapi
|
||||
type SyncKeyAPI interface {
|
||||
QueryKeyChanges(ctx context.Context, req *QueryKeyChangesRequest, res *QueryKeyChangesResponse) error
|
||||
QueryOneTimeKeys(ctx context.Context, req *QueryOneTimeKeysRequest, res *QueryOneTimeKeysResponse) error
|
||||
PerformMarkAsStaleIfNeeded(ctx context.Context, req *PerformMarkAsStaleRequest, res *struct{}) error
|
||||
}
|
||||
|
||||
type FederationKeyAPI interface {
|
||||
UploadDeviceKeysAPI
|
||||
QueryKeys(ctx context.Context, req *QueryKeysRequest, res *QueryKeysResponse) error
|
||||
QuerySignatures(ctx context.Context, req *QuerySignaturesRequest, res *QuerySignaturesResponse) error
|
||||
QueryDeviceMessages(ctx context.Context, req *QueryDeviceMessagesRequest, res *QueryDeviceMessagesResponse) error
|
||||
PerformClaimKeys(ctx context.Context, req *PerformClaimKeysRequest, res *PerformClaimKeysResponse) error
|
||||
}
|
||||
|
||||
// KeyError is returned if there was a problem performing/querying the server
|
||||
type KeyError struct {
|
||||
Err string `json:"error"`
|
||||
IsInvalidSignature bool `json:"is_invalid_signature,omitempty"` // M_INVALID_SIGNATURE
|
||||
IsMissingParam bool `json:"is_missing_param,omitempty"` // M_MISSING_PARAM
|
||||
IsInvalidParam bool `json:"is_invalid_param,omitempty"` // M_INVALID_PARAM
|
||||
}
|
||||
|
||||
func (k *KeyError) Error() string {
|
||||
return k.Err
|
||||
}
|
||||
|
||||
type DeviceMessageType int
|
||||
|
||||
const (
|
||||
TypeDeviceKeyUpdate DeviceMessageType = iota
|
||||
TypeCrossSigningUpdate
|
||||
)
|
||||
|
||||
// DeviceMessage represents the message produced into Kafka by the key server.
|
||||
type DeviceMessage struct {
|
||||
Type DeviceMessageType `json:"Type,omitempty"`
|
||||
*DeviceKeys `json:"DeviceKeys,omitempty"`
|
||||
*OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
||||
// A monotonically increasing number which represents device changes for this user.
|
||||
StreamID int64
|
||||
DeviceChangeID int64
|
||||
}
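As a concrete illustration of the struct above, a hedged sketch of building and serialising a device key update message; the field values are examples and encoding/json is assumed, matching the struct tags:

msg := api.DeviceMessage{
    Type: api.TypeDeviceKeyUpdate,
    DeviceKeys: &api.DeviceKeys{
        UserID:      "@alice:example.org",
        DeviceID:    "ABCDEFGH",
        DisplayName: "Alice's phone",
        KeyJSON:     []byte(`{"algorithms":["m.olm.v1.curve25519-aes-sha2"]}`),
    },
    StreamID: 42,
}
payload, err := json.Marshal(&msg)
if err != nil {
    // handle the error
}
_ = payload // this is what gets written to the key change stream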
|
||||
|
||||
// OutputCrossSigningKeyUpdate is an entry in the signing key update output kafka log
|
||||
type OutputCrossSigningKeyUpdate struct {
|
||||
CrossSigningKeyUpdate `json:"signing_keys"`
|
||||
}
|
||||
|
||||
type CrossSigningKeyUpdate struct {
|
||||
MasterKey *gomatrixserverlib.CrossSigningKey `json:"master_key,omitempty"`
|
||||
SelfSigningKey *gomatrixserverlib.CrossSigningKey `json:"self_signing_key,omitempty"`
|
||||
UserID string `json:"user_id"`
|
||||
}
|
||||
|
||||
// DeviceKeysEqual returns true if the device keys updates contain the
|
||||
// same display name and key JSON. This will return false if either of
|
||||
// the updates is not a device keys update, or if the user ID/device ID
|
||||
// differ between the two.
|
||||
func (m1 *DeviceMessage) DeviceKeysEqual(m2 *DeviceMessage) bool {
|
||||
if m1.DeviceKeys == nil || m2.DeviceKeys == nil {
|
||||
return false
|
||||
}
|
||||
if m1.UserID != m2.UserID || m1.DeviceID != m2.DeviceID {
|
||||
return false
|
||||
}
|
||||
if m1.DisplayName != m2.DisplayName {
|
||||
return false // different display names
|
||||
}
|
||||
if len(m1.KeyJSON) == 0 || len(m2.KeyJSON) == 0 {
|
||||
return false // either is empty
|
||||
}
|
||||
return bytes.Equal(m1.KeyJSON, m2.KeyJSON)
|
||||
}
|
||||
|
||||
// DeviceKeys represents a set of device keys for a single device
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||
type DeviceKeys struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// The device ID of this device
|
||||
DeviceID string
|
||||
// The device display name
|
||||
DisplayName string
|
||||
// The raw device key JSON
|
||||
KeyJSON []byte
|
||||
}
|
||||
|
||||
// WithStreamID returns a copy of this device message with the given stream ID
|
||||
func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
|
||||
return DeviceMessage{
|
||||
DeviceKeys: k,
|
||||
StreamID: streamID,
|
||||
}
|
||||
}
|
||||
|
||||
// OneTimeKeys represents a set of one-time keys for a single device
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
|
||||
type OneTimeKeys struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// The device ID of this device
|
||||
DeviceID string
|
||||
// A map of algorithm:key_id => key JSON
|
||||
KeyJSON map[string]json.RawMessage
|
||||
}
|
||||
|
||||
// Split a key in KeyJSON into algorithm and key ID
|
||||
func (k *OneTimeKeys) Split(keyIDWithAlgo string) (algo string, keyID string) {
|
||||
segments := strings.Split(keyIDWithAlgo, ":")
|
||||
return segments[0], segments[1]
|
||||
}
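A short usage sketch for Split, not part of the commit. Note that it indexes segments[1], so callers are expected to pass a well-formed "algorithm:key_id" string as used in /keys/upload:

otks := api.OneTimeKeys{
    UserID:   "@alice:example.org",
    DeviceID: "ABCDEFGH",
    KeyJSON: map[string]json.RawMessage{
        "signed_curve25519:AAAAHQ": json.RawMessage(`{"key":"base64+key","signatures":{}}`),
    },
}
for keyIDWithAlgo := range otks.KeyJSON {
    algo, keyID := otks.Split(keyIDWithAlgo) // "signed_curve25519", "AAAAHQ"
    _, _ = algo, keyID
}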
|
||||
|
||||
// OneTimeKeysCount represents the counts of one-time keys for a single device
|
||||
type OneTimeKeysCount struct {
|
||||
// The user who owns this device
|
||||
UserID string
|
||||
// The device ID of this device
|
||||
DeviceID string
|
||||
// algorithm to count, e.g.:
|
||||
// {
|
||||
// "curve25519": 10,
|
||||
// "signed_curve25519": 20
|
||||
// }
|
||||
KeyCount map[string]int
|
||||
}
|
||||
|
||||
// PerformUploadKeysRequest is the request to PerformUploadKeys
|
||||
type PerformUploadKeysRequest struct {
|
||||
UserID string // Required - User performing the request
|
||||
DeviceID string // Optional - Device performing the request, for fetching OTK count
|
||||
DeviceKeys []DeviceKeys
|
||||
OneTimeKeys []OneTimeKeys
|
||||
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
|
||||
// the display name for their respective device, and NOT to modify the keys. The key
|
||||
// itself doesn't change but it's easier to pretend to upload new keys and reuse the same code paths.
|
||||
// Without this flag, requests to modify device display names would delete device keys.
|
||||
OnlyDisplayNameUpdates bool
|
||||
}
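A hedged sketch of the display-name-only case the comment above describes: the caller re-uploads the existing key JSON unchanged and sets the flag so the update is not treated as a key change. existingKeyJSON, userAPI and ctx are hypothetical:

req := &api.PerformUploadKeysRequest{
    UserID:   "@alice:example.org",
    DeviceID: "ABCDEFGH",
    DeviceKeys: []api.DeviceKeys{{
        UserID:      "@alice:example.org",
        DeviceID:    "ABCDEFGH",
        DisplayName: "Alice's renamed phone",
        KeyJSON:     existingKeyJSON, // unchanged key JSON, fetched beforehand
    }},
    OnlyDisplayNameUpdates: true,
}
res := &api.PerformUploadKeysResponse{KeyErrors: map[string]map[string]*api.KeyError{}}
if err := userAPI.PerformUploadKeys(ctx, req, res); err != nil { // userAPI is any api.ClientKeyAPI
    // handle the error
}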
|
||||
|
||||
// PerformUploadKeysResponse is the response to PerformUploadKeys
|
||||
type PerformUploadKeysResponse struct {
|
||||
// A fatal error when processing, e.g. database failures
|
||||
Error *KeyError
|
||||
// A map of user_id -> device_id -> Error for tracking failures.
|
||||
KeyErrors map[string]map[string]*KeyError
|
||||
OneTimeKeyCounts []OneTimeKeysCount
|
||||
}
|
||||
|
||||
// PerformDeleteKeysRequest asks the keyserver to forget about certain
|
||||
// keys, and signatures related to those keys.
|
||||
type PerformDeleteKeysRequest struct {
|
||||
UserID string
|
||||
KeyIDs []gomatrixserverlib.KeyID
|
||||
}
|
||||
|
||||
// PerformDeleteKeysResponse is the response to PerformDeleteKeysRequest.
|
||||
type PerformDeleteKeysResponse struct {
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
// KeyError sets a key error field on KeyErrors
|
||||
func (r *PerformUploadKeysResponse) KeyError(userID, deviceID string, err *KeyError) {
|
||||
if r.KeyErrors[userID] == nil {
|
||||
r.KeyErrors[userID] = make(map[string]*KeyError)
|
||||
}
|
||||
r.KeyErrors[userID][deviceID] = err
|
||||
}
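Usage sketch for the helper above, not part of the commit. The receiver's KeyErrors map needs to be non-nil before the call, since the helper only initialises the inner per-device map:

res := &api.PerformUploadKeysResponse{
    KeyErrors: map[string]map[string]*api.KeyError{},
}
res.KeyError("@alice:example.org", "ABCDEFGH", &api.KeyError{
    Err:            "device key does not match the uploading device",
    IsInvalidParam: true, // corresponds to M_INVALID_PARAM, per the struct comment
})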
|
||||
|
||||
type PerformClaimKeysRequest struct {
|
||||
// Map of user_id to device_id to algorithm name
|
||||
OneTimeKeys map[string]map[string]string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type PerformClaimKeysResponse struct {
|
||||
// Map of user_id to device_id to algorithm:key_id to key JSON
|
||||
OneTimeKeys map[string]map[string]map[string]json.RawMessage
|
||||
// Map of remote server domain to error JSON
|
||||
Failures map[string]interface{}
|
||||
// Set if there was a fatal error processing this action
|
||||
Error *KeyError
|
||||
}
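The claim request/response shapes above mirror the client /keys/claim body; a hedged sketch with example values:

claimReq := &api.PerformClaimKeysRequest{
    OneTimeKeys: map[string]map[string]string{ // user_id -> device_id -> algorithm
        "@bob:remote.example": {"BOBDEVICE": "signed_curve25519"},
    },
    Timeout: 10 * time.Second,
}
claimRes := &api.PerformClaimKeysResponse{}
// After PerformClaimKeys, claimRes.OneTimeKeys["@bob:remote.example"]["BOBDEVICE"]
// maps "algorithm:key_id" to the claimed key JSON, and claimRes.Failures records
// remote servers that could not be contacted.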
|
||||
|
||||
type PerformUploadDeviceKeysRequest struct {
|
||||
gomatrixserverlib.CrossSigningKeys
|
||||
// The user that uploaded the key, should be populated by the clientapi.
|
||||
UserID string
|
||||
}
|
||||
|
||||
type PerformUploadDeviceKeysResponse struct {
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type PerformUploadDeviceSignaturesRequest struct {
|
||||
Signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice
|
||||
// The user that uploaded the sig, should be populated by the clientapi.
|
||||
UserID string
|
||||
}
|
||||
|
||||
type PerformUploadDeviceSignaturesResponse struct {
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type QueryKeysRequest struct {
|
||||
// The user ID asking for the keys, e.g. if from a client API request.
|
||||
// Will not be populated if the key request came from federation.
|
||||
UserID string
|
||||
// Maps user IDs to a list of devices
|
||||
UserToDevices map[string][]string
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
type QueryKeysResponse struct {
|
||||
// Map of remote server domain to error JSON
|
||||
Failures map[string]interface{}
|
||||
// Map of user_id to device_id to device_key
|
||||
DeviceKeys map[string]map[string]json.RawMessage
|
||||
// Maps of user_id to cross signing key
|
||||
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
// Set if there was a fatal error processing this query
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type QueryKeyChangesRequest struct {
|
||||
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
|
||||
Offset int64
|
||||
// The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
|
||||
// Use types.OffsetNewest if the offset is unknown (then check the response Offset to avoid racing).
|
||||
ToOffset int64
|
||||
}
|
||||
|
||||
type QueryKeyChangesResponse struct {
|
||||
// The set of users who have had their keys change.
|
||||
UserIDs []string
|
||||
// The latest offset represented in this response.
|
||||
Offset int64
|
||||
// Set if there was a problem handling the request.
|
||||
Error *KeyError
|
||||
}
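A paging sketch built from the comments above; syncKeyAPI is any api.SyncKeyAPI, lastOffset and ctx are hypothetical, and types.OffsetNewest is assumed to live in the userapi/types package referenced by the request comment:

req := &api.QueryKeyChangesRequest{
    Offset:   lastOffset,         // the previous res.Offset
    ToOffset: types.OffsetNewest, // as suggested by the comment above
}
var res api.QueryKeyChangesResponse
if err := syncKeyAPI.QueryKeyChanges(ctx, req, &res); err == nil && res.Error == nil {
    // res.UserIDs holds the users whose device lists changed; remember
    // res.Offset as the starting point for the next call to avoid racing.
    lastOffset = res.Offset
}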
|
||||
|
||||
type QueryOneTimeKeysRequest struct {
|
||||
// The local user to query OTK counts for
|
||||
UserID string
|
||||
// The device to query OTK counts for
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
type QueryOneTimeKeysResponse struct {
|
||||
// OTK key counts, in the extended /sync form described by https://matrix.org/docs/spec/client_server/r0.6.1#id84
|
||||
Count OneTimeKeysCount
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type QueryDeviceMessagesRequest struct {
|
||||
UserID string
|
||||
}
|
||||
|
||||
type QueryDeviceMessagesResponse struct {
|
||||
// The latest stream ID
|
||||
StreamID int64
|
||||
Devices []DeviceMessage
|
||||
Error *KeyError
|
||||
}
|
||||
|
||||
type QuerySignaturesRequest struct {
|
||||
// A map of target user ID -> target key/device IDs to retrieve signatures for
|
||||
TargetIDs map[string][]gomatrixserverlib.KeyID `json:"target_ids"`
|
||||
}
|
||||
|
||||
type QuerySignaturesResponse struct {
|
||||
// A map of target user ID -> target key/device ID -> origin user ID -> origin key/device ID -> signatures
|
||||
Signatures map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap
|
||||
// A map of target user ID -> cross-signing master key
|
||||
MasterKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
// A map of target user ID -> cross-signing self-signing key
|
||||
SelfSigningKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
// A map of target user ID -> cross-signing user-signing key
|
||||
UserSigningKeys map[string]gomatrixserverlib.CrossSigningKey
|
||||
// The request error, if any
|
||||
Error *KeyError
|
||||
}
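A sketch of walking the nested Signatures map in the response above (target user -> target key -> origin user -> origin key -> signature); res is a populated *api.QuerySignaturesResponse and fmt is assumed:

for targetUserID, byTargetKey := range res.Signatures {
    for targetKeyID, byOriginUser := range byTargetKey {
        for originUserID, byOriginKey := range byOriginUser {
            for originKeyID, sig := range byOriginKey {
                fmt.Printf("%s/%s signed by %s/%s: %s\n",
                    targetUserID, targetKeyID, originUserID, originKeyID, sig.Encode())
            }
        }
    }
}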
|
||||
|
||||
type PerformMarkAsStaleRequest struct {
|
||||
UserID string
|
||||
Domain gomatrixserverlib.ServerName
|
||||
DeviceID string
|
||||
}
|
||||
@ -37,7 +37,7 @@ type OutputReceiptEventConsumer struct {
    jetstream nats.JetStreamContext
    durable string
    topic string
    db storage.Database
    db storage.UserDatabase
    serverName gomatrixserverlib.ServerName
    syncProducer *producers.SyncAPI
    pgClient pushgateway.Client

@ -49,7 +49,7 @@ func NewOutputReceiptEventConsumer(
    process *process.ProcessContext,
    cfg *config.UserAPI,
    js nats.JetStreamContext,
    store storage.Database,
    store storage.UserDatabase,
    syncProducer *producers.SyncAPI,
    pgClient pushgateway.Client,
) *OutputReceiptEventConsumer {
userapi/consumers/devicelistupdate.go (new file, 95 lines)
@ -0,0 +1,95 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
)
|
||||
|
||||
// DeviceListUpdateConsumer consumes device list updates that came in over federation.
|
||||
type DeviceListUpdateConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
updater *internal.DeviceListUpdater
|
||||
isLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
}
|
||||
|
||||
// NewDeviceListUpdateConsumer creates a new DeviceListConsumer. Call Start() to begin consuming from key servers.
|
||||
func NewDeviceListUpdateConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
updater *internal.DeviceListUpdater,
|
||||
) *DeviceListUpdateConsumer {
|
||||
return &DeviceListUpdateConsumer{
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
durable: cfg.Matrix.JetStream.Prefixed("KeyServerInputDeviceListConsumer"),
|
||||
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputDeviceListUpdate),
|
||||
updater: updater,
|
||||
isLocalServerName: cfg.Matrix.IsLocalServerName,
|
||||
}
|
||||
}
|
||||
|
||||
// Start consuming from key servers
|
||||
func (t *DeviceListUpdateConsumer) Start() error {
|
||||
return jetstream.JetStreamConsumer(
|
||||
t.ctx, t.jetstream, t.topic, t.durable, 1,
|
||||
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called in response to a message received on the
|
||||
// key change events topic from the key server.
|
||||
func (t *DeviceListUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||
var m gomatrixserverlib.DeviceListUpdateEvent
|
||||
if err := json.Unmarshal(msg.Data, &m); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to read from device list update input topic")
|
||||
return true
|
||||
}
|
||||
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
|
||||
if _, serverName, err := gomatrixserverlib.SplitID('@', m.UserID); err != nil {
|
||||
return true
|
||||
} else if t.isLocalServerName(serverName) {
|
||||
return true
|
||||
} else if serverName != origin {
|
||||
return true
|
||||
}
|
||||
|
||||
err := t.updater.Update(ctx, m)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"user_id": m.UserID,
|
||||
"device_id": m.DeviceID,
|
||||
"stream_id": m.StreamID,
|
||||
"prev_id": m.PrevID,
|
||||
}).WithError(err).Errorf("Failed to update device list")
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
@ -38,7 +38,7 @@ type OutputRoomEventConsumer struct {
    rsAPI rsapi.UserRoomserverAPI
    jetstream nats.JetStreamContext
    durable string
    db storage.Database
    db storage.UserDatabase
    topic string
    pgClient pushgateway.Client
    syncProducer *producers.SyncAPI

@ -53,7 +53,7 @@ func NewOutputRoomEventConsumer(
    process *process.ProcessContext,
    cfg *config.UserAPI,
    js nats.JetStreamContext,
    store storage.Database,
    store storage.UserDatabase,
    pgClient pushgateway.Client,
    rsAPI rsapi.UserRoomserverAPI,
    syncProducer *producers.SyncAPI,

@ -18,11 +18,11 @@ import (
    userAPITypes "github.com/matrix-org/dendrite/userapi/types"
)

func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
    base, baseclose := testrig.CreateBaseDendrite(t, dbType)
    t.Helper()
    connStr, close := test.PrepareDBConnectionString(t, dbType)
    db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
    db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
        ConnectionString: config.DataSource(connStr),
    }, "", 4, 0, 0, "")
    if err != nil {
userapi/consumers/signingkeyupdate.go (new file, 111 lines)
@ -0,0 +1,111 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
// SigningKeyUpdateConsumer consumes signing key updates that came in over federation.
|
||||
type SigningKeyUpdateConsumer struct {
|
||||
ctx context.Context
|
||||
jetstream nats.JetStreamContext
|
||||
durable string
|
||||
topic string
|
||||
userAPI api.UploadDeviceKeysAPI
|
||||
cfg *config.UserAPI
|
||||
isLocalServerName func(gomatrixserverlib.ServerName) bool
|
||||
}
|
||||
|
||||
// NewSigningKeyUpdateConsumer creates a new SigningKeyUpdateConsumer. Call Start() to begin consuming from key servers.
|
||||
func NewSigningKeyUpdateConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.UserAPI,
|
||||
js nats.JetStreamContext,
|
||||
userAPI api.UploadDeviceKeysAPI,
|
||||
) *SigningKeyUpdateConsumer {
|
||||
return &SigningKeyUpdateConsumer{
|
||||
ctx: process.Context(),
|
||||
jetstream: js,
|
||||
durable: cfg.Matrix.JetStream.Prefixed("KeyServerSigningKeyConsumer"),
|
||||
topic: cfg.Matrix.JetStream.Prefixed(jetstream.InputSigningKeyUpdate),
|
||||
userAPI: userAPI,
|
||||
cfg: cfg,
|
||||
isLocalServerName: cfg.Matrix.IsLocalServerName,
|
||||
}
|
||||
}
|
||||
|
||||
// Start consuming from key servers
|
||||
func (t *SigningKeyUpdateConsumer) Start() error {
|
||||
return jetstream.JetStreamConsumer(
|
||||
t.ctx, t.jetstream, t.topic, t.durable, 1,
|
||||
t.onMessage, nats.DeliverAll(), nats.ManualAck(),
|
||||
)
|
||||
}
|
||||
|
||||
// onMessage is called in response to a message received on the
|
||||
// signing key update events topic from the key server.
|
||||
func (t *SigningKeyUpdateConsumer) onMessage(ctx context.Context, msgs []*nats.Msg) bool {
|
||||
msg := msgs[0] // Guaranteed to exist if onMessage is called
|
||||
var updatePayload api.CrossSigningKeyUpdate
|
||||
if err := json.Unmarshal(msg.Data, &updatePayload); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to read from signing key update input topic")
|
||||
return true
|
||||
}
|
||||
origin := gomatrixserverlib.ServerName(msg.Header.Get("origin"))
|
||||
if _, serverName, err := gomatrixserverlib.SplitID('@', updatePayload.UserID); err != nil {
|
||||
logrus.WithError(err).Error("failed to split user id")
|
||||
return true
|
||||
} else if t.isLocalServerName(serverName) {
|
||||
logrus.Warn("dropping device key update from ourself")
|
||||
return true
|
||||
} else if serverName != origin {
|
||||
logrus.Warnf("dropping device key update, %s != %s", serverName, origin)
|
||||
return true
|
||||
}
|
||||
|
||||
keys := gomatrixserverlib.CrossSigningKeys{}
|
||||
if updatePayload.MasterKey != nil {
|
||||
keys.MasterKey = *updatePayload.MasterKey
|
||||
}
|
||||
if updatePayload.SelfSigningKey != nil {
|
||||
keys.SelfSigningKey = *updatePayload.SelfSigningKey
|
||||
}
|
||||
uploadReq := &api.PerformUploadDeviceKeysRequest{
|
||||
CrossSigningKeys: keys,
|
||||
UserID: updatePayload.UserID,
|
||||
}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
if err := t.userAPI.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes); err != nil {
|
||||
logrus.WithError(err).Error("failed to upload device keys")
|
||||
return false
|
||||
}
|
||||
if uploadRes.Error != nil {
|
||||
logrus.WithError(uploadRes.Error).Error("failed to upload device keys")
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
}
userapi/internal/cross_signing.go (new file, 587 lines)
@ -0,0 +1,587 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/curve25519"
|
||||
)
|
||||
|
||||
func sanityCheckKey(key gomatrixserverlib.CrossSigningKey, userID string, purpose gomatrixserverlib.CrossSigningKeyPurpose) error {
|
||||
// Is there exactly one key?
|
||||
if len(key.Keys) != 1 {
|
||||
return fmt.Errorf("should contain exactly one key")
|
||||
}
|
||||
|
||||
// Does the key ID match the key value? Iterates exactly once
|
||||
for keyID, keyData := range key.Keys {
|
||||
b64 := keyData.Encode()
|
||||
tokens := strings.Split(string(keyID), ":")
|
||||
if len(tokens) != 2 {
|
||||
return fmt.Errorf("key ID is incorrectly formatted")
|
||||
}
|
||||
if tokens[1] != b64 {
|
||||
return fmt.Errorf("key ID isn't correct")
|
||||
}
|
||||
switch tokens[0] {
|
||||
case "ed25519":
|
||||
if len(keyData) != ed25519.PublicKeySize {
|
||||
return fmt.Errorf("ed25519 key is not the correct length")
|
||||
}
|
||||
case "curve25519":
|
||||
if len(keyData) != curve25519.PointSize {
|
||||
return fmt.Errorf("curve25519 key is not the correct length")
|
||||
}
|
||||
default:
|
||||
// We can't enforce the key length to be correct for an
|
||||
// algorithm that we don't recognise, so instead we'll
|
||||
// just make sure that it isn't incredibly excessive.
|
||||
if l := len(keyData); l > 4096 {
|
||||
return fmt.Errorf("unknown key type is too long (%d bytes)", l)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check to see if the signatures make sense
|
||||
for _, forOriginUser := range key.Signatures {
|
||||
for originKeyID, originSignature := range forOriginUser {
|
||||
switch strings.SplitN(string(originKeyID), ":", 1)[0] {
|
||||
case "ed25519":
|
||||
if len(originSignature) != ed25519.SignatureSize {
|
||||
return fmt.Errorf("ed25519 signature is not the correct length")
|
||||
}
|
||||
case "curve25519":
|
||||
return fmt.Errorf("curve25519 signatures are impossible")
|
||||
default:
|
||||
if l := len(originSignature); l > 4096 {
|
||||
return fmt.Errorf("unknown signature type is too long (%d bytes)", l)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Does the key claim to be from the right user?
|
||||
if userID != key.UserID {
|
||||
return fmt.Errorf("key has a user ID mismatch")
|
||||
}
|
||||
|
||||
// Does the key contain the correct purpose?
|
||||
useful := false
|
||||
for _, usage := range key.Usage {
|
||||
if usage == purpose {
|
||||
useful = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !useful {
|
||||
return fmt.Errorf("key does not contain correct usage purpose")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
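A usage sketch for the check above, not part of the commit. The CrossSigningKey field names follow how they are accessed in this file; publicKey is a hypothetical 32-byte ed25519 public key, and the key ID is assumed to use gomatrixserverlib's unpadded standard base64 (encoding/base64 assumed):

keyID := gomatrixserverlib.KeyID("ed25519:" + base64.RawStdEncoding.EncodeToString(publicKey))
key := gomatrixserverlib.CrossSigningKey{
    UserID: "@alice:example.org",
    Usage:  []gomatrixserverlib.CrossSigningKeyPurpose{gomatrixserverlib.CrossSigningKeyPurposeMaster},
    Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
        keyID: gomatrixserverlib.Base64Bytes(publicKey),
    },
}
if err := sanityCheckKey(key, "@alice:example.org", gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
    // reject the upload with an api.KeyError{IsInvalidParam: true}
}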
|
||||
|
||||
// nolint:gocyclo
|
||||
func (a *UserInternalAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
|
||||
// Find the keys to store.
|
||||
byPurpose := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
|
||||
toStore := types.CrossSigningKeyMap{}
|
||||
hasMasterKey := false
|
||||
|
||||
if len(req.MasterKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.MasterKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Master key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster] = req.MasterKey
|
||||
for _, key := range req.MasterKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeMaster] = key
|
||||
}
|
||||
hasMasterKey = true
|
||||
}
|
||||
|
||||
if len(req.SelfSigningKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.SelfSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Self-signing key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = req.SelfSigningKey
|
||||
for _, key := range req.SelfSigningKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning] = key
|
||||
}
|
||||
}
|
||||
|
||||
if len(req.UserSigningKey.Keys) > 0 {
|
||||
if err := sanityCheckKey(req.UserSigningKey, req.UserID, gomatrixserverlib.CrossSigningKeyPurposeUserSigning); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "User-signing key sanity check failed: " + err.Error(),
|
||||
IsInvalidParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
byPurpose[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = req.UserSigningKey
|
||||
for _, key := range req.UserSigningKey.Keys { // iterates once, see sanityCheckKey
|
||||
toStore[gomatrixserverlib.CrossSigningKeyPurposeUserSigning] = key
|
||||
}
|
||||
}
|
||||
|
||||
// If there's nothing to do then stop here.
|
||||
if len(toStore) == 0 {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No keys were supplied in the request",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// We can't have a self-signing or user-signing key without a master
|
||||
// key, so make sure we have one of those. We will also only actually do
|
||||
// something if any of the specified keys in the request are different
|
||||
// to what we've got in the database, to avoid generating key change
|
||||
// notifications unnecessarily.
|
||||
existingKeys, err := a.KeyDatabase.CrossSigningKeysDataForUser(ctx, req.UserID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "Retrieving cross-signing keys from database failed: " + err.Error(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we still can't find a master key for the user then stop the upload.
|
||||
// This satisfies the "Fails to upload self-signing key without master key" test.
|
||||
if !hasMasterKey {
|
||||
if _, hasMasterKey = existingKeys[gomatrixserverlib.CrossSigningKeyPurposeMaster]; !hasMasterKey {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "No master key was found",
|
||||
IsMissingParam: true,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if anything actually changed compared to what we have in the database.
|
||||
changed := false
|
||||
for _, purpose := range []gomatrixserverlib.CrossSigningKeyPurpose{
|
||||
gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||
} {
|
||||
old, gotOld := existingKeys[purpose]
|
||||
new, gotNew := toStore[purpose]
|
||||
if gotOld != gotNew {
|
||||
// A new key purpose has been specified that we didn't know before,
|
||||
// or one has been removed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
if !bytes.Equal(old, new) {
|
||||
// One of the existing keys for a purpose we already knew about has
|
||||
// changed.
|
||||
changed = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !changed {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Store the keys.
|
||||
if err := a.KeyDatabase.StoreCrossSigningKeysForUser(ctx, req.UserID, toStore); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningKeysForUser: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Now upload any signatures that were included with the keys.
|
||||
for _, key := range byPurpose {
|
||||
var targetKeyID gomatrixserverlib.KeyID
|
||||
for targetKey := range key.Keys { // iterates once, see sanityCheckKey
|
||||
targetKeyID = targetKey
|
||||
}
|
||||
for sigUserID, forSigUserID := range key.Signatures {
|
||||
if sigUserID != req.UserID {
|
||||
continue
|
||||
}
|
||||
for sigKeyID, sigBytes := range forSigUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(ctx, sigUserID, sigKeyID, req.UserID, targetKeyID, sigBytes); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.StoreCrossSigningSigsForTarget: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, generate a notification that we updated the keys.
|
||||
update := api.CrossSigningKeyUpdate{
|
||||
UserID: req.UserID,
|
||||
}
|
||||
if mk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeMaster]; ok {
|
||||
update.MasterKey = &mk
|
||||
}
|
||||
if ssk, ok := byPurpose[gomatrixserverlib.CrossSigningKeyPurposeSelfSigning]; ok {
|
||||
update.SelfSigningKey = &ssk
|
||||
}
|
||||
if update.MasterKey == nil && update.SelfSigningKey == nil {
|
||||
return nil
|
||||
}
|
||||
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUploadDeviceSignatures(ctx context.Context, req *api.PerformUploadDeviceSignaturesRequest, res *api.PerformUploadDeviceSignaturesResponse) error {
|
||||
// Before we do anything, we need the master and self-signing keys for this user.
|
||||
// Then we can verify the signatures make sense.
|
||||
queryReq := &api.QueryKeysRequest{
|
||||
UserID: req.UserID,
|
||||
UserToDevices: map[string][]string{},
|
||||
}
|
||||
queryRes := &api.QueryKeysResponse{}
|
||||
for userID := range req.Signatures {
|
||||
queryReq.UserToDevices[userID] = []string{}
|
||||
}
|
||||
_ = a.QueryKeys(ctx, queryReq, queryRes)
|
||||
|
||||
selfSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
otherSignatures := map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
|
||||
// Sort signatures into two groups: one where people have signed their own
|
||||
// keys and one where people have signed someone else's
|
||||
for userID, forUserID := range req.Signatures {
|
||||
for keyID, keyOrDevice := range forUserID {
|
||||
switch key := keyOrDevice.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
if key.UserID == req.UserID {
|
||||
if _, ok := selfSignatures[userID]; !ok {
|
||||
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
selfSignatures[userID][keyID] = keyOrDevice
|
||||
} else {
|
||||
if _, ok := otherSignatures[userID]; !ok {
|
||||
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
otherSignatures[userID][keyID] = keyOrDevice
|
||||
}
|
||||
|
||||
case *gomatrixserverlib.DeviceKeys:
|
||||
if key.UserID == req.UserID {
|
||||
if _, ok := selfSignatures[userID]; !ok {
|
||||
selfSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
selfSignatures[userID][keyID] = keyOrDevice
|
||||
} else {
|
||||
if _, ok := otherSignatures[userID]; !ok {
|
||||
otherSignatures[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice{}
|
||||
}
|
||||
otherSignatures[userID][keyID] = keyOrDevice
|
||||
}
|
||||
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := a.processSelfSignatures(ctx, selfSignatures); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.processSelfSignatures: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := a.processOtherSignatures(ctx, req.UserID, queryRes, otherSignatures); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.processOtherSignatures: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Finally, generate a notification that we updated the signatures.
|
||||
for userID := range req.Signatures {
|
||||
masterKey := queryRes.MasterKeys[userID]
|
||||
selfSigningKey := queryRes.SelfSigningKeys[userID]
|
||||
update := api.CrossSigningKeyUpdate{
|
||||
UserID: userID,
|
||||
MasterKey: &masterKey,
|
||||
SelfSigningKey: &selfSigningKey,
|
||||
}
|
||||
if err := a.KeyChangeProducer.ProduceSigningKeyUpdate(update); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.Producer.ProduceSigningKeyUpdate: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) processSelfSignatures(
|
||||
ctx context.Context,
|
||||
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
|
||||
) error {
|
||||
// Here we will process:
|
||||
// * The user signing their own devices using their self-signing key
|
||||
// * The user signing their master key using one of their devices
|
||||
|
||||
for targetUserID, forTargetUserID := range signatures {
|
||||
for targetKeyID, signature := range forTargetUserID {
|
||||
switch sig := signature.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
for keyID := range sig.Keys {
|
||||
split := strings.SplitN(string(keyID), ":", 2)
|
||||
if len(split) > 1 && gomatrixserverlib.KeyID(split[1]) == targetKeyID {
|
||||
targetKeyID = keyID // contains the ed25519: or other scheme
|
||||
break
|
||||
}
|
||||
}
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case *gomatrixserverlib.DeviceKeys:
|
||||
for originUserID, forOriginUserID := range sig.Signatures {
|
||||
for originKeyID, originSig := range forOriginUserID {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, originUserID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected type assertion")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) processOtherSignatures(
|
||||
ctx context.Context, userID string, queryRes *api.QueryKeysResponse,
|
||||
signatures map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.CrossSigningForKeyOrDevice,
|
||||
) error {
|
||||
// Here we will process:
|
||||
// * A user signing someone else's master keys using their user-signing keys
|
||||
|
||||
for targetUserID, forTargetUserID := range signatures {
|
||||
for _, signature := range forTargetUserID {
|
||||
switch sig := signature.CrossSigningBody.(type) {
|
||||
case *gomatrixserverlib.CrossSigningKey:
|
||||
// Find the local copy of the master key. We'll use this to be
|
||||
// sure that the supplied stanza matches the key that we think it
|
||||
// should be.
|
||||
masterKey, ok := queryRes.MasterKeys[targetUserID]
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find master key for user %q", targetUserID)
|
||||
}
|
||||
|
||||
// For each key ID, write the signatures. Maybe there'll be more
|
||||
// than one algorithm in the future so it's best not to focus on
|
||||
// everything being ed25519:.
|
||||
for targetKeyID, suppliedKeyData := range sig.Keys {
|
||||
// The master key will be supplied in the request, but we should
|
||||
// make sure that it matches what we think the master key should
|
||||
// actually be.
|
||||
localKeyData, lok := masterKey.Keys[targetKeyID]
|
||||
if !lok {
|
||||
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
|
||||
} else if !bytes.Equal(suppliedKeyData, localKeyData) {
|
||||
return fmt.Errorf("uploaded master key %q for user %q doesn't match local copy", targetKeyID, targetUserID)
|
||||
}
|
||||
|
||||
// We only care about the signatures from the uploading user, so
|
||||
// we will ignore anything that didn't originate from them.
|
||||
userSigs, ok := sig.Signatures[userID]
|
||||
if !ok {
|
||||
return fmt.Errorf("there are no signatures on master key %q from uploading user %q", targetKeyID, userID)
|
||||
}
|
||||
|
||||
for originKeyID, originSig := range userSigs {
|
||||
if err := a.KeyDatabase.StoreCrossSigningSigsForTarget(
|
||||
ctx, userID, originKeyID, targetUserID, targetKeyID, originSig,
|
||||
); err != nil {
|
||||
return fmt.Errorf("a.DB.StoreCrossSigningKeysForTarget: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
// Users should only be signing another person's master key,
|
||||
// so if we're here, it's probably because it's actually a
|
||||
// gomatrixserverlib.DeviceKeys, which doesn't make sense.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) crossSigningKeysFromDatabase(
|
||||
ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse,
|
||||
) {
|
||||
for targetUserID := range req.UserToDevices {
|
||||
keys, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing keys for user %q", targetUserID)
|
||||
continue
|
||||
}
|
||||
|
||||
for keyType, key := range keys {
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
for id := range key.Keys {
|
||||
keyID = id
|
||||
break
|
||||
}
|
||||
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, keyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
logrus.WithError(err).Errorf("Failed to get cross-signing signatures for user %q key %q", targetUserID, keyID)
|
||||
continue
|
||||
}
|
||||
|
||||
appendSignature := func(originUserID string, originKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) {
|
||||
if key.Signatures == nil {
|
||||
key.Signatures = types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := key.Signatures[originUserID]; !ok {
|
||||
key.Signatures[originUserID] = make(map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes)
|
||||
}
|
||||
key.Signatures[originUserID][originKeyID] = signature
|
||||
}
|
||||
|
||||
for originUserID, forOrigin := range sigMap {
|
||||
for originKeyID, signature := range forOrigin {
|
||||
switch {
|
||||
case req.UserID != "" && originUserID == req.UserID:
|
||||
// Include signatures that we created
|
||||
appendSignature(originUserID, originKeyID, signature)
|
||||
case originUserID == targetUserID:
|
||||
// Include signatures that were created by the person whose key
|
||||
// we are processing
|
||||
appendSignature(originUserID, originKeyID, signature)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch keyType {
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
|
||||
res.MasterKeys[targetUserID] = key
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
|
||||
res.SelfSigningKeys[targetUserID] = key
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
|
||||
res.UserSigningKeys[targetUserID] = key
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QuerySignatures(ctx context.Context, req *api.QuerySignaturesRequest, res *api.QuerySignaturesResponse) error {
|
||||
for targetUserID, forTargetUser := range req.TargetIDs {
|
||||
keyMap, err := a.KeyDatabase.CrossSigningKeysForUser(ctx, targetUserID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningKeysForUser: %s", err),
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
for targetPurpose, targetKey := range keyMap {
|
||||
switch targetPurpose {
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeMaster:
|
||||
if res.MasterKeys == nil {
|
||||
res.MasterKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.MasterKeys[targetUserID] = targetKey
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeSelfSigning:
|
||||
if res.SelfSigningKeys == nil {
|
||||
res.SelfSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.SelfSigningKeys[targetUserID] = targetKey
|
||||
|
||||
case gomatrixserverlib.CrossSigningKeyPurposeUserSigning:
|
||||
if res.UserSigningKeys == nil {
|
||||
res.UserSigningKeys = map[string]gomatrixserverlib.CrossSigningKey{}
|
||||
}
|
||||
res.UserSigningKeys[targetUserID] = targetKey
|
||||
}
|
||||
}
|
||||
|
||||
for _, targetKeyID := range forTargetUser {
|
||||
// Get own signatures only.
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, targetUserID, targetUserID, targetKeyID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.DB.CrossSigningSigsForTarget: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if res.Signatures == nil {
|
||||
res.Signatures = map[string]map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID]; !ok {
|
||||
res.Signatures[targetUserID] = map[gomatrixserverlib.KeyID]types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID][targetKeyID]; !ok {
|
||||
res.Signatures[targetUserID][targetKeyID] = types.CrossSigningSigMap{}
|
||||
}
|
||||
if _, ok := res.Signatures[targetUserID][targetKeyID][sourceUserID]; !ok {
|
||||
res.Signatures[targetUserID][targetKeyID][sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
res.Signatures[targetUserID][targetKeyID][sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
userapi/internal/device_list_update.go (new file, 579 lines)
@ -0,0 +1,579 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
var (
|
||||
deviceListUpdateCount = prometheus.NewCounterVec(
|
||||
prometheus.CounterOpts{
|
||||
Namespace: "dendrite",
|
||||
Subsystem: "keyserver",
|
||||
Name: "device_list_update",
|
||||
Help: "Number of times we have attempted to update device lists from this server",
|
||||
},
|
||||
[]string{"server"},
|
||||
)
|
||||
)
|
||||
|
||||
const requestTimeout = time.Second * 30
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
deviceListUpdateCount,
|
||||
)
|
||||
}
|
||||
|
||||
// DeviceListUpdater handles device list updates from remote servers.
|
||||
//
|
||||
// In the case where we have the prev_id for an update, the updater just stores the update (after acquiring a per-user lock).
|
||||
// In the case where we do not have the prev_id for an update, the updater marks the user_id as stale and notifies
|
||||
// a worker to get the latest device list for this user. Note: stream IDs are scoped per user so missing a prev_id
|
||||
// for a (user, device) does not mean that DEVICE is outdated as the previous ID could be for a different device:
|
||||
// we have to invalidate all devices for that user. Once the list has been fetched, the per-user lock is acquired and the
|
||||
// updater stores the latest list along with the latest stream ID.
|
||||
//
|
||||
// On startup, the updater spins up N workers which are responsible for querying device keys from remote servers.
|
||||
// Workers are scoped by homeserver domain, with one worker responsible for many domains, determined by hashing
|
||||
// mod N the server name. Work is sent via a channel which just serves to "poke" the worker as the data is retrieved
|
||||
// from the database (which allows us to batch requests to the same server). This has a number of desirable properties:
|
||||
// - We guarantee only 1 in-flight /keys/query request per server at any time as there is exactly 1 worker responsible
|
||||
// for that domain.
|
||||
// - We don't have unbounded growth in proportion to the number of servers (this is more important in a P2P world where
|
||||
// we have many many servers)
|
||||
// - We can adjust concurrency (at the cost of memory usage) by tuning N, to accommodate mobile devices vs servers.
|
||||
//
|
||||
// The downsides are that:
|
||||
// - Query requests can get queued behind other servers if they hash to the same worker, even if there are other free
|
||||
// workers elsewhere. Whilst suboptimal, provided we cap how long a single request can last (e.g using context timeouts)
|
||||
// we guarantee we will get around to it. Also, more users on a given server does not increase the number of requests
|
||||
// (as /keys/query allows multiple users to be specified) so being stuck behind matrix.org won't materially be any worse
|
||||
// than being stuck behind foo.bar
|
||||
//
|
||||
// In the event that the query fails, a lock is acquired and the server name along with the time to wait before retrying is
|
||||
// set in a map. A restarter goroutine periodically probes this map and injects servers which are ready to be retried.
|
||||
type DeviceListUpdater struct {
|
||||
process *process.ProcessContext
|
||||
// A map from user_id to a mutex. Used when we are missing prev IDs so we don't make more than 1
|
||||
// request to the remote server and race.
|
||||
// TODO: Put in an LRU cache to bound growth
|
||||
userIDToMutex map[string]*sync.Mutex
|
||||
mu *sync.Mutex // protects UserIDToMutex
|
||||
|
||||
db DeviceListUpdaterDatabase
|
||||
api DeviceListUpdaterAPI
|
||||
producer KeyChangeProducer
|
||||
fedClient fedsenderapi.KeyserverFederationAPI
|
||||
workerChans []chan gomatrixserverlib.ServerName
|
||||
thisServer gomatrixserverlib.ServerName
|
||||
|
||||
// When device lists are stale for a user, they get inserted into this map with a channel which `Update` will
|
||||
// block on or timeout via a select.
|
||||
userIDToChan map[string]chan bool
|
||||
userIDToChanMu *sync.Mutex
|
||||
rsAPI rsapi.KeyserverRoomserverAPI
|
||||
}
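A minimal sketch of the "hash the server name mod N" worker assignment described in the comment above; the hash/fnv import in this file suggests an FNV hash, but treat the exact function as an assumption rather than the commit's implementation:

// assignWorker picks which worker channel services a given remote server, so
// at most one /keys/query request is in flight per server name at any time.
func assignWorker(workerChans []chan gomatrixserverlib.ServerName, server gomatrixserverlib.ServerName) chan gomatrixserverlib.ServerName {
    h := fnv.New32a()
    _, _ = h.Write([]byte(server))
    return workerChans[int(h.Sum32())%len(workerChans)]
}

// Poking the responsible worker is then a channel send, e.g.
//   assignWorker(u.workerChans, "remote.example") <- "remote.example"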
|
||||
|
||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||
// Useful for testing.
|
||||
type DeviceListUpdaterDatabase interface {
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any provided `keys` already have a `KeyJSON` or `StreamID`, it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
|
||||
}
|
||||
|
||||
type DeviceListUpdaterAPI interface {
|
||||
PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error
|
||||
}
|
||||
|
||||
// KeyChangeProducer is the interface for producers.KeyChange useful for testing.
|
||||
type KeyChangeProducer interface {
|
||||
ProduceKeyChanges(keys []api.DeviceMessage) error
|
||||
}
|
||||
|
||||
// NewDeviceListUpdater creates a new updater which fetches fresh device lists when they go stale.
|
||||
func NewDeviceListUpdater(
|
||||
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
||||
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||
rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
|
||||
) *DeviceListUpdater {
|
||||
return &DeviceListUpdater{
|
||||
process: process,
|
||||
userIDToMutex: make(map[string]*sync.Mutex),
|
||||
mu: &sync.Mutex{},
|
||||
db: db,
|
||||
api: api,
|
||||
producer: producer,
|
||||
fedClient: fedClient,
|
||||
thisServer: thisServer,
|
||||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||
userIDToChan: make(map[string]chan bool),
|
||||
userIDToChanMu: &sync.Mutex{},
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
}
|
||||
|
||||
// Start the device list updater, which will try to refresh any stale device lists.
|
||||
func (u *DeviceListUpdater) Start() error {
|
||||
for i := 0; i < len(u.workerChans); i++ {
|
||||
// Allocate a small buffer per channel.
|
||||
// If the buffer limit is reached, backpressure will cause the processing of EDUs
|
||||
// to stop (in this transaction) until key requests can be made.
|
||||
ch := make(chan gomatrixserverlib.ServerName, 10)
|
||||
u.workerChans[i] = ch
|
||||
go u.worker(ch)
|
||||
}
|
||||
|
||||
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
offset, step := time.Second*10, time.Second
|
||||
if max := len(staleLists); max > 120 {
|
||||
step = (time.Second * 120) / time.Duration(max)
|
||||
}
|
||||
for _, userID := range staleLists {
|
||||
userID := userID // otherwise we are only sending the last entry
|
||||
time.AfterFunc(offset, func() {
|
||||
u.notifyWorkers(userID)
|
||||
})
|
||||
offset += step
|
||||
}
|
||||
return nil
|
||||
}
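// Note that Start does not refresh every stale list immediately: each refresh is scheduled with
// time.AfterFunc, starting 10 seconds after startup and spaced so that even a large backlog is
// spread over roughly two minutes, presumably to avoid hammering remote servers right after a
// restart.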
|
||||
|
||||
// CleanUp removes stale device entries for users we don't share a room with anymore
|
||||
func (u *DeviceListUpdater) CleanUp() error {
|
||||
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
res := rsapi.QueryLeftUsersResponse{}
|
||||
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(res.LeftUsers) == 0 {
|
||||
return nil
|
||||
}
|
||||
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
|
||||
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||
u.mu.Lock()
|
||||
defer u.mu.Unlock()
|
||||
if u.userIDToMutex[userID] == nil {
|
||||
u.userIDToMutex[userID] = &sync.Mutex{}
|
||||
}
|
||||
return u.userIDToMutex[userID]
|
||||
}
|
||||
|
||||
// ManualUpdate invalidates the device list for the given user and fetches the latest and tracks it.
|
||||
// Blocks until the device list is synced or the timeout is reached.
|
||||
func (u *DeviceListUpdater) ManualUpdate(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) error {
|
||||
mu := u.mutex(userID)
|
||||
mu.Lock()
|
||||
err := u.db.MarkDeviceListStale(ctx, userID, true)
|
||||
mu.Unlock()
|
||||
if err != nil {
|
||||
return fmt.Errorf("ManualUpdate: failed to mark device list for %s as stale: %w", userID, err)
|
||||
}
|
||||
u.notifyWorkers(userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Update blocks until the update has been stored in the database. It blocks primarily for satisfying sytest,
|
||||
// which assumes when /send 200 OKs that the device lists have been updated.
|
||||
func (u *DeviceListUpdater) Update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) error {
|
||||
isDeviceListStale, err := u.update(ctx, event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isDeviceListStale {
|
||||
// poke workers to handle stale device lists
|
||||
u.notifyWorkers(event.UserID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) update(ctx context.Context, event gomatrixserverlib.DeviceListUpdateEvent) (bool, error) {
|
||||
mu := u.mutex(event.UserID)
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
// check if we have the prev IDs
|
||||
exists, err := u.db.PrevIDsExists(ctx, event.UserID, event.PrevID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to check prev IDs exist for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
// if this is the first time we're hearing about this user, sync the device list manually.
|
||||
if len(event.PrevID) == 0 {
|
||||
exists = false
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"prev_ids_exist": exists,
|
||||
"user_id": event.UserID,
|
||||
"device_id": event.DeviceID,
|
||||
"stream_id": event.StreamID,
|
||||
"prev_ids": event.PrevID,
|
||||
"display_name": event.DeviceDisplayName,
|
||||
"deleted": event.Deleted,
|
||||
}).Trace("DeviceListUpdater.Update")
|
||||
|
||||
// if we haven't missed anything update the database and notify users
|
||||
if exists || event.Deleted {
|
||||
k := event.Keys
|
||||
if event.Deleted {
|
||||
k = nil
|
||||
}
|
||||
keys := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: k,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
StreamID: event.StreamID,
|
||||
},
|
||||
}
|
||||
|
||||
// DeviceKeysJSON will side-effect modify this, so it needs
|
||||
// to be a copy, not sharing any pointers with the above.
|
||||
deviceKeysCopy := *keys[0].DeviceKeys
|
||||
deviceKeysCopy.KeyJSON = nil
|
||||
existingKeys := []api.DeviceMessage{
|
||||
{
|
||||
Type: keys[0].Type,
|
||||
DeviceKeys: &deviceKeysCopy,
|
||||
StreamID: keys[0].StreamID,
|
||||
},
|
||||
}
|
||||
|
||||
// fetch what keys we had already and only emit changes
|
||||
if err = u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||
// non-fatal, log and continue
|
||||
util.GetLogger(ctx).WithError(err).WithField("user_id", event.UserID).Errorf(
|
||||
"failed to query device keys json for calculating diffs",
|
||||
)
|
||||
}
|
||||
|
||||
err = u.db.StoreRemoteDeviceKeys(ctx, keys, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store remote device keys for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
|
||||
if err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false); err != nil {
|
||||
return false, fmt.Errorf("failed to produce device key changes for %s (%s): %w", event.UserID, event.DeviceID, err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
err = u.db.MarkDeviceListStale(ctx, event.UserID, true)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to mark device list for %s as stale: %w", event.UserID, err)
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) notifyWorkers(userID string) {
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
hash := fnv.New32a()
|
||||
_, _ = hash.Write([]byte(remoteServer))
|
||||
index := int(int64(hash.Sum32()) % int64(len(u.workerChans)))
|
||||
|
||||
ch := u.assignChannel(userID)
|
||||
u.workerChans[index] <- remoteServer
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(10 * time.Second):
|
||||
// we don't return an error in this case as it's not a failure condition.
|
||||
// we mainly block for the benefit of sytest anyway
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) assignChannel(userID string) chan bool {
|
||||
u.userIDToChanMu.Lock()
|
||||
defer u.userIDToChanMu.Unlock()
|
||||
if ch, ok := u.userIDToChan[userID]; ok {
|
||||
return ch
|
||||
}
|
||||
ch := make(chan bool)
|
||||
u.userIDToChan[userID] = ch
|
||||
return ch
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) clearChannel(userID string) {
|
||||
u.userIDToChanMu.Lock()
|
||||
defer u.userIDToChanMu.Unlock()
|
||||
if ch, ok := u.userIDToChan[userID]; ok {
|
||||
close(ch)
|
||||
delete(u.userIDToChan, userID)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) worker(ch chan gomatrixserverlib.ServerName) {
|
||||
retries := make(map[gomatrixserverlib.ServerName]time.Time)
|
||||
retriesMu := &sync.Mutex{}
|
||||
// restarter goroutine which will inject failed servers into ch when it is time
|
||||
go func() {
|
||||
var serversToRetry []gomatrixserverlib.ServerName
|
||||
for {
|
||||
serversToRetry = serversToRetry[:0] // reuse memory
|
||||
time.Sleep(time.Second)
|
||||
retriesMu.Lock()
|
||||
now := time.Now()
|
||||
for srv, retryAt := range retries {
|
||||
if now.After(retryAt) {
|
||||
serversToRetry = append(serversToRetry, srv)
|
||||
}
|
||||
}
|
||||
for _, srv := range serversToRetry {
|
||||
delete(retries, srv)
|
||||
}
|
||||
retriesMu.Unlock()
|
||||
for _, srv := range serversToRetry {
|
||||
ch <- srv
|
||||
}
|
||||
}
|
||||
}()
|
||||
for serverName := range ch {
|
||||
retriesMu.Lock()
|
||||
_, exists := retries[serverName]
|
||||
retriesMu.Unlock()
|
||||
if exists {
|
||||
// Don't retry a server that we're already waiting for.
|
||||
continue
|
||||
}
|
||||
waitTime, shouldRetry := u.processServer(serverName)
|
||||
if shouldRetry {
|
||||
retriesMu.Lock()
|
||||
if _, exists = retries[serverName]; !exists {
|
||||
retries[serverName] = time.Now().Add(waitTime)
|
||||
}
|
||||
retriesMu.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) processServer(serverName gomatrixserverlib.ServerName) (time.Duration, bool) {
|
||||
ctx := u.process.Context()
|
||||
logger := util.GetLogger(ctx).WithField("server_name", serverName)
|
||||
deviceListUpdateCount.WithLabelValues(string(serverName)).Inc()
|
||||
|
||||
waitTime := defaultWaitTime // How long should we wait to try again?
|
||||
successCount := 0 // How many user requests succeeded?
|
||||
|
||||
userIDs, err := u.db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{serverName})
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to load stale device lists")
|
||||
return waitTime, true
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, userID := range userIDs {
|
||||
// always clear the channel to unblock Update calls regardless of success/failure
|
||||
u.clearChannel(userID)
|
||||
}
|
||||
}()
|
||||
|
||||
for _, userID := range userIDs {
|
||||
userWait, err := u.processServerUser(ctx, serverName, userID)
|
||||
if err != nil {
|
||||
if userWait > waitTime {
|
||||
waitTime = userWait
|
||||
}
|
||||
break
|
||||
}
|
||||
successCount++
|
||||
}
|
||||
|
||||
allUsersSucceeded := successCount == len(userIDs)
|
||||
if !allUsersSucceeded {
|
||||
logger.WithFields(logrus.Fields{
|
||||
"total": len(userIDs),
|
||||
"succeeded": successCount,
|
||||
"failed": len(userIDs) - successCount,
|
||||
"wait_time": waitTime,
|
||||
}).Debug("Failed to query device keys for some users")
|
||||
}
|
||||
return waitTime, !allUsersSucceeded
|
||||
}
|
||||
|
||||
func (u *DeviceListUpdater) processServerUser(ctx context.Context, serverName gomatrixserverlib.ServerName, userID string) (time.Duration, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
|
||||
defer cancel()
|
||||
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"user_id": userID,
|
||||
})
|
||||
res, err := u.fedClient.GetUserDevices(ctx, u.thisServer, serverName, userID)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
switch e := err.(type) {
|
||||
case *json.UnmarshalTypeError, *json.SyntaxError:
|
||||
logger.WithError(err).Debugf("Device list update for %q contained invalid JSON", userID)
|
||||
return defaultWaitTime, nil
|
||||
case *fedsenderapi.FederationClientError:
|
||||
if e.RetryAfter > 0 {
|
||||
return e.RetryAfter, err
|
||||
} else if e.Blacklisted {
|
||||
return time.Hour * 8, err
|
||||
}
|
||||
case net.Error:
|
||||
// Use the default waitTime, if it's a timeout.
|
||||
// It probably doesn't make sense to try further users.
|
||||
if !e.Timeout() {
|
||||
logger.WithError(e).Debug("GetUserDevices returned net.Error")
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
case gomatrix.HTTPError:
|
||||
// The remote server returned an error, give it some time to recover.
|
||||
// This is to avoid spamming remote servers, which may not be Matrix servers anymore.
|
||||
if e.Code >= 300 {
|
||||
logger.WithError(e).Debug("GetUserDevices returned gomatrix.HTTPError")
|
||||
return hourWaitTime, err
|
||||
}
|
||||
default:
|
||||
// Something else failed
|
||||
logger.WithError(err).Debugf("GetUserDevices returned unknown error type: %T", err)
|
||||
return time.Minute * 10, err
|
||||
}
|
||||
}
|
||||
if res.UserID != userID {
|
||||
logger.WithError(err).Debugf("User ID %q in device list update response doesn't match expected %q", res.UserID, userID)
|
||||
return defaultWaitTime, nil
|
||||
}
|
||||
if res.MasterKey != nil || res.SelfSigningKey != nil {
|
||||
uploadReq := &api.PerformUploadDeviceKeysRequest{
|
||||
UserID: userID,
|
||||
}
|
||||
uploadRes := &api.PerformUploadDeviceKeysResponse{}
|
||||
if res.MasterKey != nil {
|
||||
if err = sanityCheckKey(*res.MasterKey, userID, gomatrixserverlib.CrossSigningKeyPurposeMaster); err == nil {
|
||||
uploadReq.MasterKey = *res.MasterKey
|
||||
}
|
||||
}
|
||||
if res.SelfSigningKey != nil {
|
||||
if err = sanityCheckKey(*res.SelfSigningKey, userID, gomatrixserverlib.CrossSigningKeyPurposeSelfSigning); err == nil {
|
||||
uploadReq.SelfSigningKey = *res.SelfSigningKey
|
||||
}
|
||||
}
|
||||
_ = u.api.PerformUploadDeviceKeys(ctx, uploadReq, uploadRes)
|
||||
}
|
||||
err = u.updateDeviceList(&res)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Fetched device list but failed to store/emit it")
|
||||
return defaultWaitTime, err
|
||||
}
|
||||
return defaultWaitTime, nil
|
||||
}
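// To summarise the retry handling above: invalid JSON from the remote is treated as non-retryable
// and uses the default wait time; context deadline errors, non-timeout network errors and unknown
// error types back off for ten minutes; an explicit RetryAfter from the federation client is
// honoured; blacklisted servers wait eight hours; and HTTP error responses wait hourWaitTime
// (an hour in production builds, see the build-tagged constants below).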
|
||||
|
||||
func (u *DeviceListUpdater) updateDeviceList(res *gomatrixserverlib.RespUserDevices) error {
|
||||
ctx := context.Background() // we've got the keys, don't time out when persisting them to the database.
|
||||
keys := make([]api.DeviceMessage, len(res.Devices))
|
||||
existingKeys := make([]api.DeviceMessage, len(res.Devices))
|
||||
for i, device := range res.Devices {
|
||||
keyJSON, err := json.Marshal(device.Keys)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("keys", device.Keys).Error("failed to marshal keys, skipping device")
|
||||
continue
|
||||
}
|
||||
keys[i] = api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: res.StreamID,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: device.DeviceID,
|
||||
DisplayName: device.DisplayName,
|
||||
UserID: res.UserID,
|
||||
KeyJSON: keyJSON,
|
||||
},
|
||||
}
|
||||
existingKeys[i] = api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: res.UserID,
|
||||
DeviceID: device.DeviceID,
|
||||
},
|
||||
}
|
||||
}
|
||||
// fetch what keys we had already and only emit changes
|
||||
if err := u.db.DeviceKeysJSON(ctx, existingKeys); err != nil {
|
||||
// non-fatal, log and continue
|
||||
util.GetLogger(ctx).WithError(err).WithField("user_id", res.UserID).Errorf(
|
||||
"failed to query device keys json for calculating diffs",
|
||||
)
|
||||
}
|
||||
|
||||
err := u.db.StoreRemoteDeviceKeys(ctx, keys, []string{res.UserID})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to store remote device keys: %w", err)
|
||||
}
|
||||
err = u.db.MarkDeviceListStale(ctx, res.UserID, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to mark device list as fresh: %w", err)
|
||||
}
|
||||
err = emitDeviceKeyChanges(u.producer, existingKeys, keys, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to emit key changes for fresh device list: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
22 userapi/internal/device_list_update_default.go Normal file
|
@ -0,0 +1,22 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
const defaultWaitTime = time.Minute
|
||||
const hourWaitTime = time.Hour
|
25 userapi/internal/device_list_update_sytest.go Normal file
|
@ -0,0 +1,25 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build vw
|
||||
|
||||
package internal
|
||||
|
||||
import "time"
|
||||
|
||||
// Sytest expects to receive a `/devices` request. The way it is implemented in Dendrite
// results in a one-hour wait time from a previous device, so the test times out. This is fine for
// production, but makes an otherwise passing test fail.
|
||||
const defaultWaitTime = time.Second
|
||||
const hourWaitTime = time.Second
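// These constants are selected via Go build constraints: this file is only compiled when the `vw`
// build tag is set (for example with `go test -tags vw`, which the sytest images are presumed to
// use), while device_list_update_default.go provides the production values for normal builds.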
|
431 userapi/internal/device_list_update_test.go Normal file
|
@ -0,0 +1,431 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
var (
|
||||
ctx = context.Background()
|
||||
)
|
||||
|
||||
type mockKeyChangeProducer struct {
|
||||
events []api.DeviceMessage
|
||||
}
|
||||
|
||||
func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
p.events = append(p.events, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeviceListUpdaterDatabase struct {
|
||||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int64) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
var result []string
|
||||
for userID, isStale := range d.staleUsers {
|
||||
if !isStale {
|
||||
continue
|
||||
}
|
||||
_, remoteServer, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(domains) == 0 {
|
||||
result = append(result, userID)
|
||||
continue
|
||||
}
|
||||
for _, d := range domains {
|
||||
if remoteServer == d {
|
||||
result = append(result, userID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *mockDeviceListUpdaterDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
d.staleUsers[userID] = isStale
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) isStale(userID string) bool {
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
return d.staleUsers[userID]
|
||||
}
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys.
|
||||
func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clear []string) error {
|
||||
d.storedKeys = append(d.storedKeys, keys...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
return d.prevIDsExist(userID, prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockDeviceListUpdaterAPI struct {
|
||||
}
|
||||
|
||||
func (d *mockDeviceListUpdaterAPI) PerformUploadDeviceKeys(ctx context.Context, req *api.PerformUploadDeviceKeysRequest, res *api.PerformUploadDeviceKeysResponse) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type roundTripper struct {
|
||||
fn func(*http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
func (t *roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return t.fn(req)
|
||||
}
|
||||
|
||||
func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrixserverlib.FederationClient {
|
||||
_, pkey, _ := ed25519.GenerateKey(nil)
|
||||
fedClient := gomatrixserverlib.NewFederationClient(
|
||||
[]*gomatrixserverlib.SigningIdentity{
|
||||
{
|
||||
ServerName: gomatrixserverlib.ServerName("example.test"),
|
||||
KeyID: gomatrixserverlib.KeyID("ed25519:test"),
|
||||
PrivateKey: pkey,
|
||||
},
|
||||
},
|
||||
)
|
||||
fedClient.Client = *gomatrixserverlib.NewClient(
|
||||
gomatrixserverlib.WithTransport(&roundTripper{tripper}),
|
||||
)
|
||||
return fedClient
|
||||
}
|
||||
|
||||
// Test that the device keys get persisted and emitted if we have the previous IDs.
|
||||
func TestUpdateHavePrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Foo Bar",
|
||||
Deleted: false,
|
||||
DeviceID: "FOO",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int64{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
want := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: event.StreamID,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: event.DeviceID,
|
||||
DisplayName: event.DeviceDisplayName,
|
||||
KeyJSON: event.Keys,
|
||||
UserID: event.UserID,
|
||||
},
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s incorrectly marked as stale", event.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that device keys are fetched from the remote server if we are missing prev IDs
|
||||
// and that the user's devices are marked as stale until it succeeds.
|
||||
func TestUpdateNoPrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
remoteUserID := "@alice:example.somewhere"
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
keyJSON := `{"user_id":"` + remoteUserID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + remoteUserID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
defer wg.Done()
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(remoteUserID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + remoteUserID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
DeviceDisplayName: "Mobile Phone",
|
||||
Deleted: false,
|
||||
DeviceID: "another_device_id",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int64{3},
|
||||
StreamID: 4,
|
||||
UserID: remoteUserID,
|
||||
}
|
||||
err := updater.Update(ctx, event)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Update returned an error: %s", err)
|
||||
}
|
||||
t.Log("waiting for /users/devices to be called...")
|
||||
wg.Wait()
|
||||
// wait a bit for db to be updated...
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
want := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
StreamID: 5,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "JLAFKJWSCS",
|
||||
DisplayName: "Mobile Phone",
|
||||
UserID: remoteUserID,
|
||||
KeyJSON: []byte(keyJSON),
|
||||
},
|
||||
}
|
||||
// Now we should have a fresh list, stored the keys and emitted something
|
||||
if db.isStale(event.UserID) {
|
||||
t.Errorf("%s still marked as stale", event.UserID)
|
||||
}
|
||||
if !reflect.DeepEqual(producer.events, []api.DeviceMessage{want}) {
|
||||
t.Logf("len got %d len want %d", len(producer.events[0].KeyJSON), len(want.KeyJSON))
|
||||
t.Errorf("Update didn't produce correct event, got %v want %v", producer.events, want)
|
||||
}
|
||||
if !reflect.DeepEqual(db.storedKeys, []api.DeviceMessage{want}) {
|
||||
t.Errorf("DB didn't store correct event, got %v want %v", db.storedKeys, want)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Test that if we make N calls to ManualUpdate for the same user, we only do it once, assuming the
|
||||
// update is still ongoing.
|
||||
func TestDebounce(t *testing.T) {
|
||||
t.Skipf("panic on closed channel on GHA")
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
ap := &mockDeviceListUpdaterAPI{}
|
||||
producer := &mockKeyChangeProducer{}
|
||||
fedCh := make(chan *http.Response, 1)
|
||||
srv := gomatrixserverlib.ServerName("example.com")
|
||||
userID := "@alice:example.com"
|
||||
keyJSON := `{"user_id":"` + userID + `","device_id":"JLAFKJWSCS","algorithms":["m.olm.v1.curve25519-aes-sha2","m.megolm.v1.aes-sha2"],"keys":{"curve25519:JLAFKJWSCS":"3C5BFWi2Y8MaVvjM8M22DBmh24PmgR0nPvJOIArzgyI","ed25519:JLAFKJWSCS":"lEuiRJBit0IG6nUf5pUzWTUEsRVVe/HJkoKuEww9ULI"},"signatures":{"` + userID + `":{"ed25519:JLAFKJWSCS":"dSO80A01XiigH3uBiDVx/EjzaoycHcjq9lfQX0uWsqxl2giMIiSPR8a4d291W1ihKJL/a+myXS367WT6NAIcBA"}}}`
|
||||
incomingFedReq := make(chan struct{})
|
||||
fedClient := newFedClient(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL.Path != "/_matrix/federation/v1/user/devices/"+url.PathEscape(userID) {
|
||||
return nil, fmt.Errorf("test: invalid path: %s", req.URL.Path)
|
||||
}
|
||||
close(incomingFedReq)
|
||||
return <-fedCh, nil
|
||||
})
|
||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
|
||||
if err := updater.Start(); err != nil {
|
||||
t.Fatalf("failed to start updater: %s", err)
|
||||
}
|
||||
|
||||
// hit this 5 times
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := updater.ManualUpdate(context.Background(), srv, userID); err != nil {
|
||||
t.Errorf("ManualUpdate: %s", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// wait until the updater hits federation
|
||||
select {
|
||||
case <-incomingFedReq:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timed out waiting for updater to hit federation")
|
||||
}
|
||||
|
||||
// user should be marked as stale
|
||||
if !db.isStale(userID) {
|
||||
t.Errorf("user %s not marked as stale", userID)
|
||||
}
|
||||
// now send the response over federation
|
||||
fedCh <- &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: io.NopCloser(strings.NewReader(`
|
||||
{
|
||||
"user_id": "` + userID + `",
|
||||
"stream_id": 5,
|
||||
"devices": [
|
||||
{
|
||||
"device_id": "JLAFKJWSCS",
|
||||
"keys": ` + keyJSON + `,
|
||||
"device_display_name": "Mobile Phone"
|
||||
}
|
||||
]
|
||||
}
|
||||
`)),
|
||||
}
|
||||
close(fedCh)
|
||||
// wait until all 5 ManualUpdates return. If we hit federation again we won't send a response
|
||||
// and should panic with read on a closed channel
|
||||
wg.Wait()
|
||||
|
||||
// user is no longer stale now
|
||||
if db.isStale(userID) {
|
||||
t.Errorf("user %s is marked as stale", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
|
||||
base, _, _ := testrig.Base(nil)
|
||||
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return db, clearDB
|
||||
}
|
||||
|
||||
type mockKeyserverRoomserverAPI struct {
|
||||
leftUsers []string
|
||||
}
|
||||
|
||||
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||
res.LeftUsers = m.leftUsers
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestDeviceListUpdater_CleanUp(t *testing.T) {
|
||||
processCtx := process.NewProcessContext()
|
||||
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
|
||||
// Bob is not joined to any of our rooms
|
||||
rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clearDB := mustCreateKeyserverDB(t, dbType)
|
||||
defer clearDB()
|
||||
|
||||
// This should not get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// this one should get deleted
|
||||
if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
updater := NewDeviceListUpdater(processCtx, db, nil,
|
||||
nil, nil,
|
||||
0, rsAPI, "test")
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// check that we still have Alice in our stale list
|
||||
staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
// There should only be Alice
|
||||
wantCount := 1
|
||||
if count := len(staleUsers); count != wantCount {
|
||||
t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
|
||||
}
|
||||
|
||||
if staleUsers[0] != alice.ID {
|
||||
t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
|
||||
}
|
||||
})
|
||||
}
|
798 userapi/internal/key_api.go Normal file
|
@ -0,0 +1,798 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package internal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
)
|
||||
|
||||
func (a *UserInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) error {
|
||||
userIDs, latest, err := a.KeyDatabase.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Offset = latest
|
||||
res.UserIDs = userIDs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) error {
|
||||
res.KeyErrors = make(map[string]map[string]*api.KeyError)
|
||||
if len(req.DeviceKeys) > 0 {
|
||||
a.uploadLocalDeviceKeys(ctx, req, res)
|
||||
}
|
||||
if len(req.OneTimeKeys) > 0 {
|
||||
a.uploadOneTimeKeys(ctx, req, res)
|
||||
}
|
||||
otks, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.OneTimeKeyCounts = []api.OneTimeKeysCount{*otks}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformClaimKeys(ctx context.Context, req *api.PerformClaimKeysRequest, res *api.PerformClaimKeysResponse) error {
|
||||
res.OneTimeKeys = make(map[string]map[string]map[string]json.RawMessage)
|
||||
res.Failures = make(map[string]interface{})
|
||||
// wrap request map in a top-level by-domain map
|
||||
domainToDeviceKeys := make(map[string]map[string]map[string]string)
|
||||
for userID, val := range req.OneTimeKeys {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
nested, ok := domainToDeviceKeys[string(serverName)]
|
||||
if !ok {
|
||||
nested = make(map[string]map[string]string)
|
||||
}
|
||||
nested[userID] = val
|
||||
domainToDeviceKeys[string(serverName)] = nested
|
||||
}
|
||||
for domain, local := range domainToDeviceKeys {
|
||||
if !a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
// claim local keys
|
||||
keys, err := a.KeyDatabase.ClaimKeys(ctx, local)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to ClaimKeys locally: %s", err),
|
||||
}
|
||||
}
|
||||
util.GetLogger(ctx).WithField("keys_claimed", len(keys)).WithField("num_users", len(local)).Info("Claimed local keys")
|
||||
for _, key := range keys {
|
||||
_, ok := res.OneTimeKeys[key.UserID]
|
||||
if !ok {
|
||||
res.OneTimeKeys[key.UserID] = make(map[string]map[string]json.RawMessage)
|
||||
}
|
||||
_, ok = res.OneTimeKeys[key.UserID][key.DeviceID]
|
||||
if !ok {
|
||||
res.OneTimeKeys[key.UserID][key.DeviceID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for keyID, keyJSON := range key.KeyJSON {
|
||||
res.OneTimeKeys[key.UserID][key.DeviceID][keyID] = keyJSON
|
||||
}
|
||||
}
|
||||
delete(domainToDeviceKeys, domain)
|
||||
}
|
||||
if len(domainToDeviceKeys) > 0 {
|
||||
a.claimRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) claimRemoteKeys(
|
||||
ctx context.Context, timeout time.Duration, res *api.PerformClaimKeysResponse, domainToDeviceKeys map[string]map[string]map[string]string,
|
||||
) {
|
||||
var wg sync.WaitGroup // Wait for fan-out goroutines to finish
|
||||
var mu sync.Mutex // Protects the response struct
|
||||
var claimed int // Number of keys claimed in total
|
||||
var failures int // Number of servers we failed to ask
|
||||
|
||||
util.GetLogger(ctx).Infof("Claiming remote keys from %d server(s)", len(domainToDeviceKeys))
|
||||
wg.Add(len(domainToDeviceKeys))
|
||||
|
||||
for d, k := range domainToDeviceKeys {
|
||||
go func(domain string, keysToClaim map[string]map[string]string) {
|
||||
fedCtx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
defer wg.Done()
|
||||
|
||||
claimKeyRes, err := a.FedClient.ClaimKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(domain), keysToClaim)
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).WithField("server", domain).Error("ClaimKeys failed")
|
||||
res.Failures[domain] = map[string]interface{}{
|
||||
"message": err.Error(),
|
||||
}
|
||||
failures++
|
||||
return
|
||||
}
|
||||
|
||||
for userID, deviceIDToKeys := range claimKeyRes.OneTimeKeys {
|
||||
res.OneTimeKeys[userID] = make(map[string]map[string]json.RawMessage)
|
||||
for deviceID, keys := range deviceIDToKeys {
|
||||
res.OneTimeKeys[userID][deviceID] = keys
|
||||
claimed += len(keys)
|
||||
}
|
||||
}
|
||||
}(d, k)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"num_keys": claimed,
|
||||
"num_failures": failures,
|
||||
}).Infof("Claimed remote keys from %d server(s)", len(domainToDeviceKeys))
|
||||
}
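// The fan-out above follows a common Go shape: one goroutine per remote domain, a mutex guarding
// the shared response and counters, and a WaitGroup to join before logging the totals. A minimal
// sketch of the same pattern (names such as domains and claim are illustrative, not part of this
// change):
//
//	var wg sync.WaitGroup
//	var mu sync.Mutex
//	results := make(map[string]int)
//	wg.Add(len(domains))
//	for _, d := range domains {
//		go func(domain string) {
//			defer wg.Done()
//			n, err := claim(domain) // stand-in for FedClient.ClaimKeys
//			mu.Lock()
//			defer mu.Unlock()
//			if err != nil {
//				return // the real code records the failure in res.Failures
//			}
//			results[domain] = n
//		}(d)
//	}
//	wg.Wait()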
|
||||
|
||||
func (a *UserInternalAPI) PerformDeleteKeys(ctx context.Context, req *api.PerformDeleteKeysRequest, res *api.PerformDeleteKeysResponse) error {
|
||||
if err := a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, req.KeyIDs); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to delete device keys: %s", err),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryOneTimeKeys(ctx context.Context, req *api.QueryOneTimeKeysRequest, res *api.QueryOneTimeKeysResponse) error {
|
||||
count, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("Failed to query OTK counts: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
res.Count = *count
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.QueryDeviceMessagesRequest, res *api.QueryDeviceMessagesResponse) error {
|
||||
msgs, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query DB for device keys: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
maxStreamID := int64(0)
|
||||
// remove deleted devices
|
||||
var result []api.DeviceMessage
|
||||
for _, m := range msgs {
|
||||
if m.StreamID > maxStreamID {
|
||||
maxStreamID = m.StreamID
|
||||
}
|
||||
if m.KeyJSON == nil || len(m.KeyJSON) == 0 {
|
||||
continue
|
||||
}
|
||||
result = append(result, m)
|
||||
}
|
||||
res.Devices = result
|
||||
res.StreamID = maxStreamID
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformMarkAsStaleIfNeeded marks the user's device list as stale if the given deviceID is not present
// in our database.
|
||||
func (a *UserInternalAPI) PerformMarkAsStaleIfNeeded(ctx context.Context, req *api.PerformMarkAsStaleRequest, res *struct{}) error {
|
||||
knownDevices, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, []string{}, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(knownDevices) == 0 {
|
||||
return nil // fmt.Errorf("unknown user %s", req.UserID)
|
||||
}
|
||||
|
||||
for i := range knownDevices {
|
||||
if knownDevices[i].DeviceID == req.DeviceID {
|
||||
return nil // we already know about this device
|
||||
}
|
||||
}
|
||||
|
||||
return a.Updater.ManualUpdate(ctx, req.Domain, req.UserID)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func (a *UserInternalAPI) QueryKeys(ctx context.Context, req *api.QueryKeysRequest, res *api.QueryKeysResponse) error {
|
||||
var respMu sync.Mutex
|
||||
res.DeviceKeys = make(map[string]map[string]json.RawMessage)
|
||||
res.MasterKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.SelfSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.UserSigningKeys = make(map[string]gomatrixserverlib.CrossSigningKey)
|
||||
res.Failures = make(map[string]interface{})
|
||||
|
||||
// make a map from domain to device keys
|
||||
domainToDeviceKeys := make(map[string]map[string][]string)
|
||||
domainToCrossSigningKeys := make(map[string]map[string]struct{})
|
||||
for userID, deviceIDs := range req.UserToDevices {
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
domain := string(serverName)
|
||||
// query local devices
|
||||
if a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
deviceKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query local device keys: %s", err),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// pull out display names after we have the keys so we handle wildcards correctly
|
||||
var dids []string
|
||||
for _, dk := range deviceKeys {
|
||||
dids = append(dids, dk.DeviceID)
|
||||
}
|
||||
var queryRes api.QueryDeviceInfosResponse
|
||||
err = a.QueryDeviceInfos(ctx, &api.QueryDeviceInfosRequest{
|
||||
DeviceIDs: dids,
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).Warnf("Failed to QueryDeviceInfos for device IDs, display names will be missing")
|
||||
}
|
||||
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
for _, dk := range deviceKeys {
|
||||
if len(dk.KeyJSON) == 0 {
|
||||
continue // don't include blank keys
|
||||
}
|
||||
// inject display name if known (either locally or remotely)
|
||||
displayName := dk.DisplayName
|
||||
if queryRes.DeviceInfo[dk.DeviceID].DisplayName != "" {
|
||||
displayName = queryRes.DeviceInfo[dk.DeviceID].DisplayName
|
||||
}
|
||||
dk.KeyJSON, _ = sjson.SetBytes(dk.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{displayName})
|
||||
res.DeviceKeys[userID][dk.DeviceID] = dk.KeyJSON
|
||||
}
|
||||
} else {
|
||||
domainToDeviceKeys[domain] = make(map[string][]string)
|
||||
domainToDeviceKeys[domain][userID] = append(domainToDeviceKeys[domain][userID], deviceIDs...)
|
||||
}
|
||||
// work out if our cross-signing request for this user was
|
||||
// satisfied, if not add them to the list of things to fetch
|
||||
if _, ok := res.MasterKeys[userID]; !ok {
|
||||
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
||||
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
||||
}
|
||||
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
||||
}
|
||||
if _, ok := res.SelfSigningKeys[userID]; !ok {
|
||||
if _, ok := domainToCrossSigningKeys[domain]; !ok {
|
||||
domainToCrossSigningKeys[domain] = make(map[string]struct{})
|
||||
}
|
||||
domainToCrossSigningKeys[domain][userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// attempt to satisfy key queries from the local database first as we should get device updates pushed to us
|
||||
domainToDeviceKeys = a.remoteKeysFromDatabase(ctx, res, &respMu, domainToDeviceKeys)
|
||||
if len(domainToDeviceKeys) > 0 || len(domainToCrossSigningKeys) > 0 {
|
||||
// perform key queries for remote devices
|
||||
a.queryRemoteKeys(ctx, req.Timeout, res, domainToDeviceKeys, domainToCrossSigningKeys)
|
||||
}
|
||||
|
||||
// Now that we've done the potentially expensive work of asking the federation,
|
||||
// try filling the cross-signing keys from the database that we know about.
|
||||
a.crossSigningKeysFromDatabase(ctx, req, res)
|
||||
|
||||
// Finally, append signatures that we know about
|
||||
// TODO: This is horrible because we need to round-trip the signature from
|
||||
// JSON, add the signatures and marshal it again, for some reason?
|
||||
|
||||
for targetUserID, masterKey := range res.MasterKeys {
|
||||
if masterKey.Signatures == nil {
|
||||
masterKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for targetKeyID := range masterKey.Keys {
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
continue
|
||||
}
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if _, ok := masterKey.Signatures[sourceUserID]; !ok {
|
||||
masterKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
masterKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for targetUserID, forUserID := range res.DeviceKeys {
|
||||
for targetKeyID, key := range forUserID {
|
||||
sigMap, err := a.KeyDatabase.CrossSigningSigsForTarget(ctx, req.UserID, targetUserID, gomatrixserverlib.KeyID(targetKeyID))
|
||||
if err != nil {
|
||||
// Stop executing the function if the context was canceled/the deadline was exceeded,
|
||||
// as we can't continue without a valid context.
|
||||
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||
return nil
|
||||
}
|
||||
logrus.WithError(err).Errorf("a.KeyDatabase.CrossSigningSigsForTarget failed")
|
||||
continue
|
||||
}
|
||||
if len(sigMap) == 0 {
|
||||
continue
|
||||
}
|
||||
var deviceKey gomatrixserverlib.DeviceKeys
|
||||
if err = json.Unmarshal(key, &deviceKey); err != nil {
|
||||
continue
|
||||
}
|
||||
if deviceKey.Signatures == nil {
|
||||
deviceKey.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for sourceUserID, forSourceUser := range sigMap {
|
||||
for sourceKeyID, sourceSig := range forSourceUser {
|
||||
if _, ok := deviceKey.Signatures[sourceUserID]; !ok {
|
||||
deviceKey.Signatures[sourceUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
deviceKey.Signatures[sourceUserID][sourceKeyID] = sourceSig
|
||||
}
|
||||
}
|
||||
if js, err := json.Marshal(deviceKey); err == nil {
|
||||
res.DeviceKeys[targetUserID][targetKeyID] = js
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) remoteKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, domainToDeviceKeys map[string]map[string][]string,
|
||||
) map[string]map[string][]string {
|
||||
fetchRemote := make(map[string]map[string][]string)
|
||||
for domain, userToDeviceMap := range domainToDeviceKeys {
|
||||
for userID, deviceIDs := range userToDeviceMap {
|
||||
// we can't safely return keys from the db when all devices are requested as we don't
|
||||
// know if one has just been added.
|
||||
if len(deviceIDs) > 0 {
|
||||
err := a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, deviceIDs)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
util.GetLogger(ctx).WithError(err).Error("populateResponseWithDeviceKeysFromDatabase")
|
||||
}
|
||||
// fetch device lists from remote
|
||||
if _, ok := fetchRemote[domain]; !ok {
|
||||
fetchRemote[domain] = make(map[string][]string)
|
||||
}
|
||||
fetchRemote[domain][userID] = append(fetchRemote[domain][userID], deviceIDs...)
|
||||
|
||||
}
|
||||
}
|
||||
return fetchRemote
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) queryRemoteKeys(
|
||||
ctx context.Context, timeout time.Duration, res *api.QueryKeysResponse,
|
||||
domainToDeviceKeys map[string]map[string][]string, domainToCrossSigningKeys map[string]map[string]struct{},
|
||||
) {
|
||||
resultCh := make(chan *gomatrixserverlib.RespQueryKeys, len(domainToDeviceKeys))
|
||||
// allows us to wait until all federation servers have been poked
|
||||
var wg sync.WaitGroup
|
||||
// mutex for writing directly to res (e.g failures)
|
||||
var respMu sync.Mutex
|
||||
|
||||
domains := map[string]struct{}{}
|
||||
for domain := range domainToDeviceKeys {
|
||||
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
for domain := range domainToCrossSigningKeys {
|
||||
if a.Config.Matrix.IsLocalServerName(gomatrixserverlib.ServerName(domain)) {
|
||||
continue
|
||||
}
|
||||
domains[domain] = struct{}{}
|
||||
}
|
||||
wg.Add(len(domains))
|
||||
|
||||
// fan out
|
||||
for domain := range domains {
|
||||
go a.queryRemoteKeysOnServer(
|
||||
ctx, domain, domainToDeviceKeys[domain], domainToCrossSigningKeys[domain],
|
||||
&wg, &respMu, timeout, resultCh, res,
|
||||
)
|
||||
}
|
||||
|
||||
// Close the result channel when the goroutines have quit so the for .. range exits
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(resultCh)
|
||||
}()
|
||||
|
||||
processResult := func(result *gomatrixserverlib.RespQueryKeys) {
|
||||
respMu.Lock()
|
||||
defer respMu.Unlock()
|
||||
for userID, nest := range result.DeviceKeys {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
for deviceID, deviceKey := range nest {
|
||||
keyJSON, err := json.Marshal(deviceKey)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
res.DeviceKeys[userID][deviceID] = keyJSON
|
||||
}
|
||||
}
|
||||
|
||||
for userID, body := range result.MasterKeys {
|
||||
res.MasterKeys[userID] = body
|
||||
}
|
||||
|
||||
for userID, body := range result.SelfSigningKeys {
|
||||
res.SelfSigningKeys[userID] = body
|
||||
}
|
||||
|
||||
// TODO: do we want to persist these somewhere now
|
||||
// that we have fetched them?
|
||||
}
|
||||
|
||||
for result := range resultCh {
|
||||
processResult(result)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) queryRemoteKeysOnServer(
|
||||
ctx context.Context, serverName string, devKeys map[string][]string, crossSigningKeys map[string]struct{},
|
||||
wg *sync.WaitGroup, respMu *sync.Mutex, timeout time.Duration, resultCh chan<- *gomatrixserverlib.RespQueryKeys,
|
||||
res *api.QueryKeysResponse,
|
||||
) {
|
||||
defer wg.Done()
|
||||
fedCtx := ctx
|
||||
if timeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
fedCtx, cancel = context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
}
|
||||
// for users who we do not have any knowledge about, try to start doing device list updates for them
// by hitting /users/devices - otherwise fall back to /keys/query, which has nicer bulk properties but
// lacks a stream ID.
|
||||
userIDsForAllDevices := map[string]struct{}{}
|
||||
for userID, deviceIDs := range devKeys {
|
||||
if len(deviceIDs) == 0 {
|
||||
userIDsForAllDevices[userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
// for cross-signing keys, it's probably easier just to hit /keys/query if we aren't already doing
// a device list update for that user, so populate those users back into the /keys/query list
|
||||
for userID := range crossSigningKeys {
|
||||
if devKeys == nil {
|
||||
devKeys = map[string][]string{}
|
||||
}
|
||||
if _, ok := userIDsForAllDevices[userID]; !ok {
|
||||
devKeys[userID] = []string{}
|
||||
}
|
||||
}
|
||||
for userID := range userIDsForAllDevices {
|
||||
err := a.Updater.ManualUpdate(context.Background(), gomatrixserverlib.ServerName(serverName), userID)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
// refresh entries from DB: unlike remoteKeysFromDatabase we know we previously had no device info for this
|
||||
// user so the fact that we're populating all devices here isn't a problem so long as we have devices.
|
||||
err = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, nil)
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
logrus.ErrorKey: err,
|
||||
"user_id": userID,
|
||||
"server": serverName,
|
||||
}).Error("Failed to manually update device lists for user")
|
||||
// try to do it via /keys/query
|
||||
devKeys[userID] = []string{}
|
||||
continue
|
||||
}
|
||||
}
|
||||
if len(devKeys) == 0 {
|
||||
return
|
||||
}
|
||||
queryKeysResp, err := a.FedClient.QueryKeys(fedCtx, a.Config.Matrix.ServerName, gomatrixserverlib.ServerName(serverName), devKeys)
|
||||
if err == nil {
|
||||
resultCh <- &queryKeysResp
|
||||
return
|
||||
}
|
||||
respMu.Lock()
|
||||
res.Failures[serverName] = map[string]interface{}{
|
||||
"message": err.Error(),
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
// last ditch, use the cache only. This is good for when clients hit /keys/query and the remote server
|
||||
// is down, better to return something than nothing at all. Clients can know about the failure by
|
||||
// inspecting the failures map though so they can know it's a cached response.
|
||||
for userID, dkeys := range devKeys {
|
||||
// drop the error as it's already a failure at this point
|
||||
_ = a.populateResponseWithDeviceKeysFromDatabase(ctx, res, respMu, userID, dkeys)
|
||||
}
|
||||
|
||||
// Sytest expects no failures, if we still could retrieve keys, e.g. from local cache
|
||||
respMu.Lock()
|
||||
if len(res.DeviceKeys) > 0 {
|
||||
delete(res.Failures, serverName)
|
||||
}
|
||||
respMu.Unlock()
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) populateResponseWithDeviceKeysFromDatabase(
|
||||
ctx context.Context, res *api.QueryKeysResponse, respMu *sync.Mutex, userID string, deviceIDs []string,
|
||||
) error {
|
||||
keys, err := a.KeyDatabase.DeviceKeysForUser(ctx, userID, deviceIDs, false)
|
||||
// if we can't query the db or there are fewer keys than requested, fetch from remote.
|
||||
if err != nil {
|
||||
return fmt.Errorf("DeviceKeysForUser %s %v failed: %w", userID, deviceIDs, err)
|
||||
}
|
||||
if len(keys) < len(deviceIDs) {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned fewer devices than requested, falling back to remote", userID)
|
||||
}
|
||||
if len(deviceIDs) == 0 && len(keys) == 0 {
|
||||
return fmt.Errorf("DeviceKeysForUser %s returned no keys but wanted all keys, falling back to remote", userID)
|
||||
}
|
||||
respMu.Lock()
|
||||
if res.DeviceKeys[userID] == nil {
|
||||
res.DeviceKeys[userID] = make(map[string]json.RawMessage)
|
||||
}
|
||||
respMu.Unlock()
|
||||
|
||||
for _, key := range keys {
|
||||
if len(key.KeyJSON) == 0 {
|
||||
continue // ignore deleted keys
|
||||
}
|
||||
// inject the display name
|
||||
key.KeyJSON, _ = sjson.SetBytes(key.KeyJSON, "unsigned", struct {
|
||||
DisplayName string `json:"device_display_name,omitempty"`
|
||||
}{key.DisplayName})
|
||||
respMu.Lock()
|
||||
res.DeviceKeys[userID][key.DeviceID] = key.KeyJSON
|
||||
respMu.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
// get a list of devices from the user API that actually exist, as
|
||||
// we won't store keys for devices that don't exist
|
||||
uapidevices := &api.QueryDevicesResponse{}
|
||||
if err := a.QueryDevices(ctx, &api.QueryDevicesRequest{UserID: req.UserID}, uapidevices); err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
return
|
||||
}
|
||||
if !uapidevices.UserExists {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("user %q does not exist", req.UserID),
|
||||
}
|
||||
return
|
||||
}
|
||||
existingDeviceMap := make(map[string]struct{}, len(uapidevices.Devices))
|
||||
for _, key := range uapidevices.Devices {
|
||||
existingDeviceMap[key.ID] = struct{}{}
|
||||
}
|
||||
|
||||
// Get all of the user existing device keys so we can check for changes.
|
||||
existingKeys, err := a.KeyDatabase.DeviceKeysForUser(ctx, req.UserID, nil, true)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to query existing device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Work out whether we have device keys in the keyserver for devices that
|
||||
// no longer exist in the user API. This is mostly an exercise to ensure
|
||||
// that we keep some integrity between the two.
|
||||
var toClean []gomatrixserverlib.KeyID
|
||||
for _, k := range existingKeys {
|
||||
if _, ok := existingDeviceMap[k.DeviceID]; !ok {
|
||||
toClean = append(toClean, gomatrixserverlib.KeyID(k.DeviceID))
|
||||
}
|
||||
}
|
||||
|
||||
if len(toClean) > 0 {
|
||||
if err = a.KeyDatabase.DeleteDeviceKeys(ctx, req.UserID, toClean); err != nil {
|
||||
logrus.WithField("user_id", req.UserID).WithError(err).Errorf("Failed to clean up %d stale keyserver device key entries", len(toClean))
|
||||
} else {
|
||||
logrus.WithField("user_id", req.UserID).Debugf("Cleaned up %d stale keyserver device key entries", len(toClean))
|
||||
}
|
||||
}
|
||||
|
||||
var keysToStore []api.DeviceMessage
|
||||
|
||||
if req.OnlyDisplayNameUpdates {
|
||||
for _, existingKey := range existingKeys {
|
||||
for _, newKey := range req.DeviceKeys {
|
||||
switch {
|
||||
case existingKey.UserID != newKey.UserID:
|
||||
continue
|
||||
case existingKey.DeviceID != newKey.DeviceID:
|
||||
continue
|
||||
case existingKey.DisplayName != newKey.DisplayName:
|
||||
existingKey.DisplayName = newKey.DisplayName
|
||||
}
|
||||
}
|
||||
keysToStore = append(keysToStore, existingKey)
|
||||
}
|
||||
} else {
|
||||
// assert that the user ID / device ID are not lying for each key
|
||||
for _, key := range req.DeviceKeys {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
_, serverName, err = gomatrixserverlib.SplitID('@', key.UserID)
|
||||
if err != nil {
|
||||
continue // ignore invalid users
|
||||
}
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
continue // ignore remote users
|
||||
}
|
||||
if len(key.KeyJSON) == 0 {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
continue // deleted keys don't need sanity checking
|
||||
}
|
||||
// check that the device in question actually exists in the user
|
||||
// API before we try and store a key for it
|
||||
if _, ok := existingDeviceMap[key.DeviceID]; !ok {
|
||||
continue
|
||||
}
|
||||
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
|
||||
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
|
||||
if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
|
||||
keysToStore = append(keysToStore, key.WithStreamID(0))
|
||||
continue
|
||||
}
|
||||
|
||||
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf(
|
||||
"user_id or device_id mismatch: users: %s - %s, devices: %s - %s",
|
||||
gotUserID, key.UserID, gotDeviceID, key.DeviceID,
|
||||
),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// store the device keys and emit changes
|
||||
err = a.KeyDatabase.StoreLocalDeviceKeys(ctx, keysToStore)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
|
||||
}
|
||||
return
|
||||
}
|
||||
err = emitDeviceKeyChanges(a.KeyChangeProducer, existingKeys, keysToStore, req.OnlyDisplayNameUpdates)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
|
||||
if req.UserID == "" {
|
||||
res.Error = &api.KeyError{
|
||||
Err: "user ID missing",
|
||||
}
|
||||
}
|
||||
if req.DeviceID != "" && len(req.OneTimeKeys) == 0 {
|
||||
counts, err := a.KeyDatabase.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: fmt.Sprintf("a.KeyDatabase.OneTimeKeysCount: %s", err),
|
||||
}
|
||||
}
|
||||
if counts != nil {
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
return
|
||||
}
|
||||
for _, key := range req.OneTimeKeys {
|
||||
// grab existing keys based on (user/device/algorithm/key ID)
|
||||
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
|
||||
i := 0
|
||||
for keyIDWithAlgo := range key.KeyJSON {
|
||||
keyIDsWithAlgorithms[i] = keyIDWithAlgo
|
||||
i++
|
||||
}
|
||||
existingKeys, err := a.KeyDatabase.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: "failed to query existing one-time keys: " + err.Error(),
|
||||
})
|
||||
continue
|
||||
}
|
||||
for keyIDWithAlgo := range existingKeys {
|
||||
// if keys exist and the JSON doesn't match, error out as the key already exists
|
||||
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
|
||||
})
|
||||
continue
|
||||
}
|
||||
}
|
||||
// store one-time keys
|
||||
counts, err := a.KeyDatabase.StoreOneTimeKeys(ctx, key)
|
||||
if err != nil {
|
||||
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
|
||||
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
|
||||
})
|
||||
continue
|
||||
}
|
||||
// collect counts
|
||||
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func emitDeviceKeyChanges(producer KeyChangeProducer, existing, new []api.DeviceMessage, onlyUpdateDisplayName bool) error {
|
||||
// if we only want to update the display names, we can skip the checks below
|
||||
if onlyUpdateDisplayName {
|
||||
return producer.ProduceKeyChanges(new)
|
||||
}
|
||||
// find keys in new that are not in existing
|
||||
var keysAdded []api.DeviceMessage
|
||||
for _, newKey := range new {
|
||||
exists := false
|
||||
for _, existingKey := range existing {
|
||||
// Do not treat the absence of keys as equal, or else we will not emit key changes
|
||||
// when users delete devices which never had a key to begin with as both KeyJSONs are nil.
|
||||
if existingKey.DeviceKeysEqual(&newKey) {
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
keysAdded = append(keysAdded, newKey)
|
||||
}
|
||||
}
|
||||
return producer.ProduceKeyChanges(keysAdded)
|
||||
}
|
161
userapi/internal/key_api_test.go
Normal file
161
userapi/internal/key_api_test.go
Normal file
|
@ -0,0 +1,161 @@
|
|||
package internal_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
base, _, _ := testrig.Base(nil)
|
||||
db, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new user db: %v", err)
|
||||
}
|
||||
return db, func() {
|
||||
base.Close()
|
||||
close()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_QueryDeviceMessages(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
type args struct {
|
||||
req *api.QueryDeviceMessagesRequest
|
||||
res *api.QueryDeviceMessagesResponse
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr bool
|
||||
want *api.QueryDeviceMessagesResponse
|
||||
}{
|
||||
{
|
||||
name: "no existing keys",
|
||||
args: args{
|
||||
req: &api.QueryDeviceMessagesRequest{
|
||||
UserID: "@doesNotExist:localhost",
|
||||
},
|
||||
res: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
want: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
{
|
||||
name: "existing user returns devices",
|
||||
args: args{
|
||||
req: &api.QueryDeviceMessagesRequest{
|
||||
UserID: alice.ID,
|
||||
},
|
||||
res: &api.QueryDeviceMessagesResponse{},
|
||||
},
|
||||
want: &api.QueryDeviceMessagesResponse{
|
||||
StreamID: 6,
|
||||
Devices: []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, StreamID: 5, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
DisplayName: "first device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("ghi"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, StreamID: 6, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "mySecondDevice",
|
||||
DisplayName: "second device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("jkl"),
|
||||
}, // streamID 6
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
deviceMessages := []api.DeviceMessage{
|
||||
{ // not the user we're looking for
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
UserID: "@doesNotExist:localhost",
|
||||
},
|
||||
// streamID 1 for this user
|
||||
},
|
||||
{ // empty keyJSON will be ignored
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
}, // streamID 1
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("abc"),
|
||||
}, // streamID 2
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("def"),
|
||||
}, // streamID 3
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte(""),
|
||||
}, // streamID 4
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "myDevice",
|
||||
DisplayName: "first device",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("ghi"),
|
||||
}, // streamID 5
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate, DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "mySecondDevice",
|
||||
UserID: alice.ID,
|
||||
KeyJSON: []byte("jkl"),
|
||||
DisplayName: "second device",
|
||||
}, // streamID 6
|
||||
},
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
if err := db.StoreLocalDeviceKeys(ctx, deviceMessages); err != nil {
|
||||
t.Fatalf("failed to store local devicesKeys")
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &internal.UserInternalAPI{
|
||||
KeyDatabase: db,
|
||||
}
|
||||
if err := a.QueryDeviceMessages(ctx, tt.args.req, tt.args.res); (err != nil) != tt.wantErr {
|
||||
t.Errorf("QueryDeviceMessages() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
got := tt.args.res
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("QueryDeviceMessages(): got:\n%+v, want:\n%+v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
|
@ -23,6 +23,7 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -32,7 +33,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/eventutil"
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
synctypes "github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
@ -44,17 +44,19 @@ import (
|
|||
)
|
||||
|
||||
type UserInternalAPI struct {
|
||||
DB storage.Database
|
||||
SyncProducer *producers.SyncAPI
|
||||
Config *config.UserAPI
|
||||
DB storage.UserDatabase
|
||||
KeyDatabase storage.KeyDatabase
|
||||
SyncProducer *producers.SyncAPI
|
||||
KeyChangeProducer *producers.KeyChange
|
||||
Config *config.UserAPI
|
||||
|
||||
DisableTLSValidation bool
|
||||
// AppServices is the list of all registered AS
|
||||
AppServices []config.ApplicationService
|
||||
KeyAPI keyapi.UserKeyAPI
|
||||
RSAPI rsapi.UserRoomserverAPI
|
||||
PgClient pushgateway.Client
|
||||
Cfg *config.UserAPI
|
||||
FedClient fedsenderapi.KeyserverFederationAPI
|
||||
Updater *DeviceListUpdater
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAccountDataRequest, res *api.InputAccountDataResponse) error {
|
||||
|
@ -221,7 +223,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||
postRegisterJoinRooms(a.Config, acc, a.RSAPI)
|
||||
|
||||
res.AccountCreated = true
|
||||
res.Account = acc
|
||||
|
@ -293,14 +295,14 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
return err
|
||||
}
|
||||
// Ask the keyserver to delete device keys and signatures for those devices
|
||||
deleteReq := &keyapi.PerformDeleteKeysRequest{
|
||||
deleteReq := &api.PerformDeleteKeysRequest{
|
||||
UserID: req.UserID,
|
||||
}
|
||||
for _, keyID := range req.DeviceIDs {
|
||||
deleteReq.KeyIDs = append(deleteReq.KeyIDs, gomatrixserverlib.KeyID(keyID))
|
||||
}
|
||||
deleteRes := &keyapi.PerformDeleteKeysResponse{}
|
||||
if err := a.KeyAPI.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
|
||||
deleteRes := &api.PerformDeleteKeysResponse{}
|
||||
if err := a.PerformDeleteKeys(ctx, deleteReq, deleteRes); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := deleteRes.Error; err != nil {
|
||||
|
@ -311,17 +313,17 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) deviceListUpdate(userID string, deviceIDs []string) error {
|
||||
deviceKeys := make([]keyapi.DeviceKeys, len(deviceIDs))
|
||||
deviceKeys := make([]api.DeviceKeys, len(deviceIDs))
|
||||
for i, did := range deviceIDs {
|
||||
deviceKeys[i] = keyapi.DeviceKeys{
|
||||
deviceKeys[i] = api.DeviceKeys{
|
||||
UserID: userID,
|
||||
DeviceID: did,
|
||||
KeyJSON: nil,
|
||||
}
|
||||
}
|
||||
|
||||
var uploadRes keyapi.PerformUploadKeysResponse
|
||||
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
|
||||
UserID: userID,
|
||||
DeviceKeys: deviceKeys,
|
||||
}, &uploadRes); err != nil {
|
||||
|
@ -385,10 +387,10 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
}
|
||||
if req.DisplayName != nil && dev.DisplayName != *req.DisplayName {
|
||||
// display name has changed: update the device key
|
||||
var uploadRes keyapi.PerformUploadKeysResponse
|
||||
if err := a.KeyAPI.PerformUploadKeys(context.Background(), &keyapi.PerformUploadKeysRequest{
|
||||
var uploadRes api.PerformUploadKeysResponse
|
||||
if err := a.PerformUploadKeys(context.Background(), &api.PerformUploadKeysRequest{
|
||||
UserID: req.RequestingUserID,
|
||||
DeviceKeys: []keyapi.DeviceKeys{
|
||||
DeviceKeys: []api.DeviceKeys{
|
||||
{
|
||||
DeviceID: dev.ID,
|
||||
DisplayName: *req.DisplayName,
|
107
userapi/producers/keychange.go
Normal file
107
userapi/producers/keychange.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package producers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KeyChange produces key change events for the sync API and federation sender to consume
|
||||
type KeyChange struct {
|
||||
Topic string
|
||||
JetStream JetStreamPublisher
|
||||
DB storage.KeyChangeDatabase
|
||||
}
|
||||
|
||||
// ProduceKeyChanges creates new change events for each key
|
||||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
userToDeviceCount := make(map[string]int)
|
||||
for _, key := range keys {
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.DeviceChangeID = id
|
||||
value, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userToDeviceCount[key.UserID]++
|
||||
}
|
||||
for userID, count := range userToDeviceCount {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"user_id": userID,
|
||||
"num_key_changes": count,
|
||||
}).Tracef("Produced to key change topic '%s'", p.Topic)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *KeyChange) ProduceSigningKeyUpdate(key api.CrossSigningKeyUpdate) error {
|
||||
output := &api.DeviceMessage{
|
||||
Type: api.TypeCrossSigningUpdate,
|
||||
OutputCrossSigningKeyUpdate: &api.OutputCrossSigningKeyUpdate{
|
||||
CrossSigningKeyUpdate: key,
|
||||
},
|
||||
}
|
||||
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
output.DeviceChangeID = id
|
||||
|
||||
value, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"user_id": key.UserID,
|
||||
}).Tracef("Produced to cross-signing update topic '%s'", p.Topic)
|
||||
return nil
|
||||
}
|
|
@ -19,13 +19,13 @@ type JetStreamPublisher interface {
|
|||
|
||||
// SyncAPI produces messages for the Sync API server to consume.
|
||||
type SyncAPI struct {
|
||||
db storage.Database
|
||||
db storage.Notification
|
||||
producer JetStreamPublisher
|
||||
clientDataTopic string
|
||||
notificationDataTopic string
|
||||
}
|
||||
|
||||
func NewSyncAPI(db storage.Database, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
|
||||
func NewSyncAPI(db storage.UserDatabase, js JetStreamPublisher, clientDataTopic string, notificationDataTopic string) *SyncAPI {
|
||||
return &SyncAPI{
|
||||
db: db,
|
||||
producer: js,
|
||||
|
|
|
@ -90,7 +90,7 @@ type KeyBackup interface {
|
|||
|
||||
type LoginToken interface {
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
// determined by the loginTokenLifetime given to the UserDatabase constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
|
@ -130,7 +130,7 @@ type Notification interface {
|
|||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
type UserDatabase interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
|
@ -144,6 +144,78 @@ type Database interface {
|
|||
ThreePID
|
||||
}
|
||||
|
||||
type KeyChangeDatabase interface {
|
||||
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||
// `userID` is the the user who has changed their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
}
|
||||
|
||||
type KeyDatabase interface {
|
||||
KeyChangeDatabase
|
||||
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
|
||||
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
|
||||
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
|
||||
// StoreOneTimeKeys persists the given one-time keys.
|
||||
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||
|
||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device).
|
||||
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||
// Returns an error if there was a problem storing the keys.
|
||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
|
||||
|
||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||
|
||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||
// A to offset of types.OffsetNewest means no upper limit.
|
||||
// Returns the offset of the latest key change.
|
||||
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
|
||||
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
|
||||
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
|
||||
|
||||
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
|
||||
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||
|
||||
DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error
|
||||
}
|
||||
|
||||
type Statistics interface {
|
||||
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
|
||||
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)
|
||||
|
|
|
@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
|
||||
// when passing the data to trigger "NOT NULL" constraint
|
||||
var data *json.RawMessage
|
||||
if len(content) > 0 {
|
||||
data = &content
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
key_type SMALLINT NOT NULL,
|
||||
key_data TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key_type)
|
||||
);
|
||||
`
|
||||
|
||||
const selectCrossSigningKeysForUserSQL = "" +
|
||||
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||
" WHERE user_id = $1"
|
||||
|
||||
const upsertCrossSigningKeysForUserSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||
" VALUES($1, $2, $3)" +
|
||||
" ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
|
||||
|
||||
type crossSigningKeysStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningKeysForUserStmt *sql.Stmt
|
||||
upsertCrossSigningKeysForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
|
||||
s := &crossSigningKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string,
|
||||
) (r types.CrossSigningKeyMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
|
||||
r = types.CrossSigningKeyMap{}
|
||||
for rows.Next() {
|
||||
var keyTypeInt int16
|
||||
var keyData gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
|
||||
}
|
||||
r[keyType] = keyData
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||
}
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningSigsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`
|
||||
|
||||
const selectCrossSigningSigsForTargetSQL = "" +
|
||||
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
|
||||
|
||||
const upsertCrossSigningSigsForTargetSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||
" VALUES($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
|
||||
|
||||
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||
|
||||
type crossSigningSigsStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
upsertCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
deleteCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
|
||||
s := &crossSigningSigsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningSigsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "keyserver: cross signing signature indexes",
|
||||
Up: deltas.UpFixCrossSigningSignatureIndexes,
|
||||
})
|
||||
if err = m.Up(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
|
||||
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
|
||||
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) (r types.CrossSigningSigMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
|
||||
r = types.CrossSigningSigMap{}
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
var signature gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := r[userID]; !ok {
|
||||
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
r[userID][keyID] = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
signature gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
|
||||
// even are entries in this table to know if we can query for log_offset. Without the count then
|
||||
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
|
||||
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
|
||||
var count int
|
||||
_ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
|
||||
if count > 0 {
|
||||
var maxOffset int64
|
||||
_ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
|
||||
if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
|
||||
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
|
||||
|
||||
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
213
userapi/storage/postgres/device_keys_table.go
Normal file
213
userapi/storage/postgres/device_keys_table.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
-- Stores device keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
|
||||
-- This means we do not store an unbounded append-only log of device keys, which is not actually
|
||||
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||
stream_id BIGINT NOT NULL,
|
||||
display_name TEXT,
|
||||
-- Clobber based on tuple of user/device.
|
||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertDeviceKeysSQL = "" +
|
||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
const selectDeviceKeysSQL = "" +
|
||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
const deleteDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(deviceKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
|
||||
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
|
||||
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
|
||||
{&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
|
||||
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
|
||||
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].Type = api.TypeDeviceKeyUpdate
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
var result []api.DeviceMessage
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
dk := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: userID,
|
||||
},
|
||||
}
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
result = append(result, dk)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
|
|||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
dev := &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
|
|||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
if displayName != nil {
|
||||
dev.DisplayName = *displayName
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||
|
|
|
@@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
	"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"

-const selectKeysSQL = "" +
+const selectBackupKeysSQL = "" +
	"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
	"WHERE user_id = $1 AND version = $2"

@@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
		{&s.insertBackupKeyStmt, insertBackupKeySQL},
		{&s.updateBackupKeyStmt, updateBackupKeySQL},
		{&s.countKeysStmt, countKeysSQL},
-		{&s.selectKeysStmt, selectKeysSQL},
+		{&s.selectKeysStmt, selectBackupKeysSQL},
		{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
		{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
	}.Prepare(db)
127 userapi/storage/postgres/key_changes_table.go Normal file
@@ -0,0 +1,127 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
)

var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
	change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
	user_id TEXT NOT NULL,
	CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
);
`

// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
// have changed, hence we can just keep bumping the change ID for this user.
const upsertKeyChangeSQL = "" +
	"INSERT INTO keyserver_key_changes (user_id)" +
	" VALUES ($1)" +
	" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
	" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
	" RETURNING change_id"

const selectKeyChangesSQL = "" +
	"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"

type keyChangesStatements struct {
	db                   *sql.DB
	upsertKeyChangeStmt  *sql.Stmt
	selectKeyChangesStmt *sql.Stmt
}

func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
	s := &keyChangesStatements{
		db: db,
	}
	_, err := db.Exec(keyChangesSchema)
	if err != nil {
		return s, err
	}

	if err = executeMigration(context.Background(), db); err != nil {
		return nil, err
	}
	return s, sqlutil.StatementList{
		{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
		{&s.selectKeyChangesStmt, selectKeyChangesSQL},
	}.Prepare(db)
}

func executeMigration(ctx context.Context, db *sql.DB) error {
	// TODO: Remove when we are sure we are not having goose artefacts in the db
	// This forces an error, which indicates the migration is already applied, since the
	// column partition was removed from the table
	migrationName := "keyserver: refactor key changes"

	var cName string
	err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
			if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
				return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
			}
			return nil
		}
		return err
	}
	m := sqlutil.NewMigrator(db)
	m.AddMigrations(sqlutil.Migration{
		Version: migrationName,
		Up:      deltas.UpRefactorKeyChanges,
	})

	return m.Up(ctx)
}

func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
	err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
	return
}

func (s *keyChangesStatements) SelectKeyChanges(
	ctx context.Context, fromOffset, toOffset int64,
) (userIDs []string, latestOffset int64, err error) {
	latestOffset = fromOffset
	rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
	if err != nil {
		return nil, 0, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
	for rows.Next() {
		var userID string
		var offset int64
		if err := rows.Scan(&userID, &offset); err != nil {
			return nil, 0, err
		}
		if offset > latestOffset {
			latestOffset = offset
		}
		userIDs = append(userIDs, userID)
	}
	return
}
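For orientation, a minimal usage sketch of the table above; it is illustrative only and not part of this change, assumes it lived alongside the code in the same postgres package, and the user ID is a placeholder. Because of the UNIQUE (user_id) constraint there is at most one row per user, so a poll between two offsets returns each changed user once.

func exampleKeyChanges(ctx context.Context, db *sql.DB) ([]string, error) {
	kc, err := NewPostgresKeyChangesTable(db)
	if err != nil {
		return nil, err
	}
	// Upserting the same user again only bumps change_id via the sequence.
	changeID, err := kc.InsertKeyChange(ctx, "@alice:example.com") // placeholder user ID
	if err != nil {
		return nil, err
	}
	// Fetch every user whose keys changed in the offset window (0, changeID].
	userIDs, _, err := kc.SelectKeyChanges(ctx, 0, changeID)
	return userIDs, err
}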
194 userapi/storage/postgres/one_time_keys_table.go Normal file
@@ -0,0 +1,194 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
	"context"
	"database/sql"
	"encoding/json"
	"time"

	"github.com/lib/pq"
	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
)

var oneTimeKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	key_id TEXT NOT NULL,
	algorithm TEXT NOT NULL,
	ts_added_secs BIGINT NOT NULL,
	key_json TEXT NOT NULL,
	-- Clobber based on 4-tuple of user/device/key/algorithm.
	CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
);

CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
`

const upsertKeysSQL = "" +
	"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
	" VALUES ($1, $2, $3, $4, $5, $6)" +
	" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
	" DO UPDATE SET key_json = $6"

const selectOneTimeKeysSQL = "" +
	"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"

const selectKeysCountSQL = "" +
	"SELECT algorithm, COUNT(key_id) FROM " +
	" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
	" x GROUP BY algorithm"

const deleteOneTimeKeySQL = "" +
	"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"

const selectKeyByAlgorithmSQL = "" +
	"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"

const deleteOneTimeKeysSQL = "" +
	"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"

type oneTimeKeysStatements struct {
	db                       *sql.DB
	upsertKeysStmt           *sql.Stmt
	selectKeysStmt           *sql.Stmt
	selectKeysCountStmt      *sql.Stmt
	selectKeyByAlgorithmStmt *sql.Stmt
	deleteOneTimeKeyStmt     *sql.Stmt
	deleteOneTimeKeysStmt    *sql.Stmt
}

func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
	s := &oneTimeKeysStatements{
		db: db,
	}
	_, err := db.Exec(oneTimeKeysSchema)
	if err != nil {
		return nil, err
	}
	return s, sqlutil.StatementList{
		{&s.upsertKeysStmt, upsertKeysSQL},
		{&s.selectKeysStmt, selectOneTimeKeysSQL},
		{&s.selectKeysCountStmt, selectKeysCountSQL},
		{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
		{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
		{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
	}.Prepare(db)
}

func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
	rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")

	result := make(map[string]json.RawMessage)
	var (
		algorithmWithID string
		keyJSONStr      string
	)
	for rows.Next() {
		if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
			return nil, err
		}
		result[algorithmWithID] = json.RawMessage(keyJSONStr)
	}
	return result, rows.Err()
}

func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
	counts := &api.OneTimeKeysCount{
		DeviceID: deviceID,
		UserID:   userID,
		KeyCount: make(map[string]int),
	}
	rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
	for rows.Next() {
		var algorithm string
		var count int
		if err = rows.Scan(&algorithm, &count); err != nil {
			return nil, err
		}
		counts.KeyCount[algorithm] = count
	}
	return counts, nil
}

func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
	now := time.Now().Unix()
	counts := &api.OneTimeKeysCount{
		DeviceID: keys.DeviceID,
		UserID:   keys.UserID,
		KeyCount: make(map[string]int),
	}
	for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
		algo, keyID := keys.Split(keyIDWithAlgo)
		_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
			ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
		)
		if err != nil {
			return nil, err
		}
	}
	rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
	for rows.Next() {
		var algorithm string
		var count int
		if err = rows.Scan(&algorithm, &count); err != nil {
			return nil, err
		}
		counts.KeyCount[algorithm] = count
	}

	return counts, rows.Err()
}

func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
	ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
	var keyID string
	var keyJSON string
	err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
	if err != nil {
		if err == sql.ErrNoRows {
			return nil, nil
		}
		return nil, err
	}
	_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
	return map[string]json.RawMessage{
		algorithm + ":" + keyID: json.RawMessage(keyJSON),
	}, err
}

func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
	_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
	return err
}
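A small illustrative sketch of the read paths above, not part of the diff, assuming the same postgres package; the user ID, device ID and key identifier are placeholders. Key identifiers are passed in the algorithm:key_id form that the concat() in selectOneTimeKeysSQL produces.

func exampleOneTimeKeys(ctx context.Context, db *sql.DB) error {
	otk, err := NewPostgresOneTimeKeysTable(db)
	if err != nil {
		return err
	}
	// How many unclaimed one-time keys does this device still have, per algorithm?
	counts, err := otk.CountOneTimeKeys(ctx, "@alice:example.com", "ADEVICE") // placeholders
	if err != nil {
		return err
	}
	_ = counts.KeyCount["signed_curve25519"]
	// Look up specific keys by "algorithm:key_id".
	_, err = otk.SelectOneTimeKeys(ctx, "@alice:example.com", "ADEVICE", []string{"signed_curve25519:AAAAAQ"})
	return err
}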
131 userapi/storage/postgres/stale_device_lists.go Normal file
@@ -0,0 +1,131 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package postgres

import (
	"context"
	"database/sql"
	"time"

	"github.com/lib/pq"

	"github.com/matrix-org/dendrite/internal/sqlutil"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
	"github.com/matrix-org/gomatrixserverlib"
)

var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
	user_id TEXT PRIMARY KEY NOT NULL,
	domain TEXT NOT NULL,
	is_stale BOOLEAN NOT NULL,
	ts_added_secs BIGINT NOT NULL
);

CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`

const upsertStaleDeviceListSQL = "" +
	"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
	" VALUES ($1, $2, $3, $4)" +
	" ON CONFLICT (user_id)" +
	" DO UPDATE SET is_stale = $3, ts_added_secs = $4"

const selectStaleDeviceListsWithDomainsSQL = "" +
	"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"

const selectStaleDeviceListsSQL = "" +
	"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"

const deleteStaleDevicesSQL = "" +
	"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"

type staleDeviceListsStatements struct {
	upsertStaleDeviceListStmt             *sql.Stmt
	selectStaleDeviceListsWithDomainsStmt *sql.Stmt
	selectStaleDeviceListsStmt            *sql.Stmt
	deleteStaleDeviceListsStmt            *sql.Stmt
}

func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
	s := &staleDeviceListsStatements{}
	_, err := db.Exec(staleDeviceListsSchema)
	if err != nil {
		return nil, err
	}
	return s, sqlutil.StatementList{
		{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
		{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
		{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
		{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
	}.Prepare(db)
}

func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
	_, domain, err := gomatrixserverlib.SplitID('@', userID)
	if err != nil {
		return err
	}
	_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
	return err
}

func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
	// we only query for 1 domain or all domains so optimise for those use cases
	if len(domains) == 0 {
		rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
		if err != nil {
			return nil, err
		}
		return rowsToUserIDs(ctx, rows)
	}
	var result []string
	for _, domain := range domains {
		rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
		if err != nil {
			return nil, err
		}
		userIDs, err := rowsToUserIDs(ctx, rows)
		if err != nil {
			return nil, err
		}
		result = append(result, userIDs...)
	}
	return result, nil
}

// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
	ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
	stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
	_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
	return err
}

func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
	defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
	for rows.Next() {
		var userID string
		if err := rows.Scan(&userID); err != nil {
			return nil, err
		}
		result = append(result, userID)
	}
	return result, rows.Err()
}
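Again purely illustrative and not part of the change, assuming the same postgres package; the user ID and domain are placeholders. It shows the mark-then-query flow that the table supports.

func exampleStaleDeviceLists(ctx context.Context, db *sql.DB) ([]string, error) {
	sdl, err := NewPostgresStaleDeviceListsTable(db)
	if err != nil {
		return nil, err
	}
	// Flag a remote user's device list as needing a re-fetch.
	if err := sdl.InsertStaleDeviceList(ctx, "@bob:remote.example", true); err != nil { // placeholder user ID
		return nil, err
	}
	// Query stale users for one domain; pass nil to get every domain.
	return sdl.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"remote.example"})
}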
@@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
		OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
	}, nil
}

func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
	db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
	if err != nil {
		return nil, err
	}
	otk, err := NewPostgresOneTimeKeysTable(db)
	if err != nil {
		return nil, err
	}
	dk, err := NewPostgresDeviceKeysTable(db)
	if err != nil {
		return nil, err
	}
	kc, err := NewPostgresKeyChangesTable(db)
	if err != nil {
		return nil, err
	}
	sdl, err := NewPostgresStaleDeviceListsTable(db)
	if err != nil {
		return nil, err
	}
	csk, err := NewPostgresCrossSigningKeysTable(db)
	if err != nil {
		return nil, err
	}
	css, err := NewPostgresCrossSigningSigsTable(db)
	if err != nil {
		return nil, err
	}

	return &shared.KeyDatabase{
		OneTimeKeysTable:      otk,
		DeviceKeysTable:       dk,
		KeyChangesTable:       kc,
		StaleDeviceListsTable: sdl,
		CrossSigningKeysTable: csk,
		CrossSigningSigsTable: css,
		Writer:                writer,
	}, nil
}
@@ -59,6 +59,17 @@ type Database struct {
	OpenIDTokenLifetimeMS int64
}

type KeyDatabase struct {
	OneTimeKeysTable      tables.OneTimeKeys
	DeviceKeysTable       tables.DeviceKeys
	KeyChangesTable       tables.KeyChanges
	StaleDeviceListsTable tables.StaleDeviceLists
	CrossSigningKeysTable tables.CrossSigningKeys
	CrossSigningSigsTable tables.CrossSigningSigs
	DB                    *sql.DB
	Writer                sqlutil.Writer
}

const (
	// The length of generated device IDs
	deviceIDByteLength = 6
@@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
	return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}

func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
	return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
}

func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
	_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
		return err
	})
	return
}

func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
	return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}

func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
	return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}

func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
	count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
	if err != nil {
		return false, err
	}
	return count == len(prevIDs), nil
}

func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		for _, userID := range clearUserIDs {
			err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
			if err != nil {
				return err
			}
		}
		return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
	})
}

func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
	// work out the latest stream IDs for each user
	userIDToStreamID := make(map[string]int64)
	for _, k := range keys {
		userIDToStreamID[k.UserID] = 0
	}
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		for userID := range userIDToStreamID {
			streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
			if err != nil {
				return err
			}
			userIDToStreamID[userID] = streamID
		}
		// set the stream IDs for each key
		for i := range keys {
			k := keys[i]
			userIDToStreamID[k.UserID]++ // start stream from 1
			k.StreamID = userIDToStreamID[k.UserID]
			keys[i] = k
		}
		return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
	})
}

func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
	return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}

func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
	var result []api.OneTimeKeys
	err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		for userID, deviceToAlgo := range userToDeviceToAlgorithm {
			for deviceID, algo := range deviceToAlgo {
				keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
				if err != nil {
					return err
				}
				if keyJSON != nil {
					result = append(result, api.OneTimeKeys{
						UserID:   userID,
						DeviceID: deviceID,
						KeyJSON:  keyJSON,
					})
				}
			}
		}
		return nil
	})
	return result, err
}

func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
	err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
		id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
		return err
	})
	return
}

func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
	return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
}

// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
	return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}

// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
	return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
		return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
	})
}

// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		for _, deviceID := range deviceIDs {
			if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
				return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
			}
			if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
				return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
			}
			if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
				return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
			}
		}
		return nil
	})
}

// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
	keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
	if err != nil {
		return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
	}
	results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
	for purpose, key := range keyMap {
		keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
		result := gomatrixserverlib.CrossSigningKey{
			UserID: userID,
			Usage:  []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
			Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
				keyID: key,
			},
		}
		sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
		if err != nil {
			continue
		}
		for sigUserID, forSigUserID := range sigMap {
			if userID != sigUserID {
				continue
			}
			if result.Signatures == nil {
				result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
			}
			if _, ok := result.Signatures[sigUserID]; !ok {
				result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
			}
			for sigKeyID, sigBytes := range forSigUserID {
				result.Signatures[sigUserID][sigKeyID] = sigBytes
			}
		}
		results[purpose] = result
	}
	return results, nil
}

// CrossSigningKeysDataForUser returns the latest known cross-signing key data for a user, if any.
func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
	return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}

// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
	return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
}

// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		for keyType, keyData := range keyMap {
			if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
				return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
			}
		}
		return nil
	})
}

// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/device.
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
	ctx context.Context,
	originUserID string, originKeyID gomatrixserverlib.KeyID,
	targetUserID string, targetKeyID gomatrixserverlib.KeyID,
	signature gomatrixserverlib.Base64Bytes,
) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
			return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
		}
		return nil
	})
}

// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *KeyDatabase) DeleteStaleDeviceLists(
	ctx context.Context,
	userIDs []string,
) error {
	return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
	})
}
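A hedged sketch of how the claim path above is typically driven, assuming an already-constructed *KeyDatabase in the same shared package; the user and device IDs are placeholders. Selection and deletion of the one-time key happen inside a single writer transaction, so a claimed key cannot be handed out twice.

func exampleClaimKeys(ctx context.Context, d *KeyDatabase) ([]api.OneTimeKeys, error) {
	// Claim one signed_curve25519 key for a single device of a single user.
	return d.ClaimKeys(ctx, map[string]map[string]string{
		"@alice:example.com": { // placeholder user ID
			"ADEVICE": "signed_curve25519", // placeholder device ID
		},
	})
}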
101 userapi/storage/sqlite3/cross_signing_keys_table.go Normal file
@@ -0,0 +1,101 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlite3

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
	"github.com/matrix-org/dendrite/userapi/types"
	"github.com/matrix-org/gomatrixserverlib"
)

var crossSigningKeysSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
	user_id TEXT NOT NULL,
	key_type INTEGER NOT NULL,
	key_data TEXT NOT NULL,
	PRIMARY KEY (user_id, key_type)
);
`

const selectCrossSigningKeysForUserSQL = "" +
	"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
	" WHERE user_id = $1"

const upsertCrossSigningKeysForUserSQL = "" +
	"INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
	" VALUES($1, $2, $3)"

type crossSigningKeysStatements struct {
	db                                *sql.DB
	selectCrossSigningKeysForUserStmt *sql.Stmt
	upsertCrossSigningKeysForUserStmt *sql.Stmt
}

func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
	s := &crossSigningKeysStatements{
		db: db,
	}
	_, err := db.Exec(crossSigningKeysSchema)
	if err != nil {
		return nil, err
	}
	return s, sqlutil.StatementList{
		{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
		{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
	}.Prepare(db)
}

func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
	ctx context.Context, txn *sql.Tx, userID string,
) (r types.CrossSigningKeyMap, err error) {
	rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
	r = types.CrossSigningKeyMap{}
	for rows.Next() {
		var keyTypeInt int16
		var keyData gomatrixserverlib.Base64Bytes
		if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
			return nil, err
		}
		keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
		if !ok {
			return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
		}
		r[keyType] = keyData
	}
	return
}

func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
	ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
) error {
	keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
	if !ok {
		return fmt.Errorf("unknown key purpose %q", keyType)
	}
	if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
		return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
	}
	return nil
}
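An illustrative round trip, not part of the diff, assuming the same sqlite3 package; the user ID and key bytes are placeholders, and CrossSigningKeyPurposeMaster is assumed to be the gomatrixserverlib constant for the "master" purpose. A nil txn is fine here because sqlutil.TxStmt falls back to the prepared statement, as the shared storage layer does.

func exampleCrossSigningKeys(ctx context.Context, db *sql.DB) (types.CrossSigningKeyMap, error) {
	csk, err := NewSqliteCrossSigningKeysTable(db)
	if err != nil {
		return nil, err
	}
	// Store (or replace) the user's master key, then read back all purposes.
	keyData := gomatrixserverlib.Base64Bytes("placeholder-public-key") // placeholder bytes
	if err := csk.UpsertCrossSigningKeysForUser(ctx, nil, "@alice:example.com", gomatrixserverlib.CrossSigningKeyPurposeMaster, keyData); err != nil {
		return nil, err
	}
	return csk.SelectCrossSigningKeysForUser(ctx, nil, "@alice:example.com")
}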
129 userapi/storage/sqlite3/cross_signing_sigs_table.go Normal file
@@ -0,0 +1,129 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlite3

import (
	"context"
	"database/sql"
	"fmt"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
	"github.com/matrix-org/dendrite/userapi/types"
	"github.com/matrix-org/gomatrixserverlib"
)

var crossSigningSigsSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
	origin_user_id TEXT NOT NULL,
	origin_key_id TEXT NOT NULL,
	target_user_id TEXT NOT NULL,
	target_key_id TEXT NOT NULL,
	signature TEXT NOT NULL,
	PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
);

CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`

const selectCrossSigningSigsForTargetSQL = "" +
	"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
	" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"

const upsertCrossSigningSigsForTargetSQL = "" +
	"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
	" VALUES($1, $2, $3, $4, $5)"

const deleteCrossSigningSigsForTargetSQL = "" +
	"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"

type crossSigningSigsStatements struct {
	db                                  *sql.DB
	selectCrossSigningSigsForTargetStmt *sql.Stmt
	upsertCrossSigningSigsForTargetStmt *sql.Stmt
	deleteCrossSigningSigsForTargetStmt *sql.Stmt
}

func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
	s := &crossSigningSigsStatements{
		db: db,
	}
	_, err := db.Exec(crossSigningSigsSchema)
	if err != nil {
		return nil, err
	}
	m := sqlutil.NewMigrator(db)
	m.AddMigrations(sqlutil.Migration{
		Version: "keyserver: cross signing signature indexes",
		Up:      deltas.UpFixCrossSigningSignatureIndexes,
	})
	if err = m.Up(context.Background()); err != nil {
		return nil, err
	}

	return s, sqlutil.StatementList{
		{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
		{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
		{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
	}.Prepare(db)
}

func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
	ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
	rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
	r = types.CrossSigningSigMap{}
	for rows.Next() {
		var userID string
		var keyID gomatrixserverlib.KeyID
		var signature gomatrixserverlib.Base64Bytes
		if err := rows.Scan(&userID, &keyID, &signature); err != nil {
			return nil, err
		}
		if _, ok := r[userID]; !ok {
			r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
		}
		r[userID][keyID] = signature
	}
	return
}

func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
	ctx context.Context, txn *sql.Tx,
	originUserID string, originKeyID gomatrixserverlib.KeyID,
	targetUserID string, targetKeyID gomatrixserverlib.KeyID,
	signature gomatrixserverlib.Base64Bytes,
) error {
	if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
		return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
	}
	return nil
}

func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
	ctx context.Context, txn *sql.Tx,
	targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
	if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
		return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
	}
	return nil
}
@@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package deltas

import (
	"context"
	"database/sql"
	"fmt"
)

func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
	// start counting from the last max offset, else 0.
	var maxOffset int64
	var userID string
	_ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)

	_, err := tx.ExecContext(ctx, `
		-- make the new table
		DROP TABLE IF EXISTS keyserver_key_changes;
		CREATE TABLE IF NOT EXISTS keyserver_key_changes (
			change_id INTEGER PRIMARY KEY AUTOINCREMENT,
			-- The key owner
			user_id TEXT NOT NULL,
			UNIQUE (user_id)
		);
	`)
	if err != nil {
		return fmt.Errorf("failed to execute upgrade: %w", err)
	}
	// to start counting from maxOffset, insert a row with that value
	if userID != "" {
		_, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
		return err
	}
	return nil
}

func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, `
		-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
		DROP TABLE IF EXISTS keyserver_key_changes;
		CREATE TABLE IF NOT EXISTS keyserver_key_changes (
			partition BIGINT NOT NULL,
			offset BIGINT NOT NULL,
			-- The key owner
			user_id TEXT NOT NULL,
			UNIQUE (partition, offset)
		);
	`)
	if err != nil {
		return fmt.Errorf("failed to execute downgrade: %w", err)
	}
	return nil
}
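For completeness, a sketch of how a delta like the one above gets wired into the migrator, mirroring the executeMigration/NewMigrator pattern used elsewhere in this change; the helper name is illustrative and an extra import of github.com/matrix-org/dendrite/internal/sqlutil is assumed.

func runRefactorKeyChanges(ctx context.Context, db *sql.DB) error {
	m := sqlutil.NewMigrator(db)
	m.AddMigrations(sqlutil.Migration{
		Version: "keyserver: refactor key changes",
		Up:      UpRefactorKeyChanges,
	})
	// Up applies any migrations not yet recorded as executed.
	return m.Up(ctx)
}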
@@ -0,0 +1,71 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package deltas

import (
	"context"
	"database/sql"
	"fmt"
)

func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
			origin_user_id TEXT NOT NULL,
			origin_key_id TEXT NOT NULL,
			target_user_id TEXT NOT NULL,
			target_key_id TEXT NOT NULL,
			signature TEXT NOT NULL,
			PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
		);

		INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
			SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;

		DROP TABLE keyserver_cross_signing_sigs;
		ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;

		CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
	`)
	if err != nil {
		return fmt.Errorf("failed to execute upgrade: %w", err)
	}
	return nil
}

func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
	_, err := tx.ExecContext(ctx, `
		CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
			origin_user_id TEXT NOT NULL,
			origin_key_id TEXT NOT NULL,
			target_user_id TEXT NOT NULL,
			target_key_id TEXT NOT NULL,
			signature TEXT NOT NULL,
			PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
		);

		INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
			SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;

		DROP TABLE keyserver_cross_signing_sigs;
		ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;

		DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
	`)
	if err != nil {
		return fmt.Errorf("failed to execute downgrade: %w", err)
	}
	return nil
}
213 userapi/storage/sqlite3/device_keys_table.go Normal file
@@ -0,0 +1,213 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package sqlite3

import (
	"context"
	"database/sql"
	"strings"
	"time"

	"github.com/matrix-org/dendrite/internal"
	"github.com/matrix-org/dendrite/internal/sqlutil"
	"github.com/matrix-org/dendrite/userapi/api"
	"github.com/matrix-org/dendrite/userapi/storage/tables"
)

var deviceKeysSchema = `
-- Stores device keys for users
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
	user_id TEXT NOT NULL,
	device_id TEXT NOT NULL,
	ts_added_secs BIGINT NOT NULL,
	key_json TEXT NOT NULL,
	stream_id BIGINT NOT NULL,
	display_name TEXT,
	-- Clobber based on tuple of user/device.
	UNIQUE (user_id, device_id)
);
`

const upsertDeviceKeysSQL = "" +
	"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
	" VALUES ($1, $2, $3, $4, $5, $6)" +
	" ON CONFLICT (user_id, device_id)" +
	" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"

const selectDeviceKeysSQL = "" +
	"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"

const selectBatchDeviceKeysSQL = "" +
	"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"

const selectBatchDeviceKeysWithEmptiesSQL = "" +
	"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"

const selectMaxStreamForUserSQL = "" +
	"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"

const countStreamIDsForUserSQL = "" +
	"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"

const deleteDeviceKeysSQL = "" +
	"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"

const deleteAllDeviceKeysSQL = "" +
	"DELETE FROM keyserver_device_keys WHERE user_id=$1"

type deviceKeysStatements struct {
	db                                   *sql.DB
	upsertDeviceKeysStmt                 *sql.Stmt
	selectDeviceKeysStmt                 *sql.Stmt
	selectBatchDeviceKeysStmt            *sql.Stmt
	selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
	selectMaxStreamForUserStmt           *sql.Stmt
	deleteDeviceKeysStmt                 *sql.Stmt
	deleteAllDeviceKeysStmt              *sql.Stmt
}

func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
	s := &deviceKeysStatements{
		db: db,
	}
	_, err := db.Exec(deviceKeysSchema)
	if err != nil {
		return nil, err
	}
	return s, sqlutil.StatementList{
		{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
		{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
		{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
		{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
		{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
		// {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
		{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
		{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
	}.Prepare(db)
}

func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
	_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
	return err
}

func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
	_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
	return err
}

func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
	deviceIDMap := make(map[string]bool)
	for _, d := range deviceIDs {
		deviceIDMap[d] = true
	}
	var stmt *sql.Stmt
	if includeEmpty {
		stmt = s.selectBatchDeviceKeysWithEmptiesStmt
	} else {
		stmt = s.selectBatchDeviceKeysStmt
	}
	rows, err := stmt.QueryContext(ctx, userID)
	if err != nil {
		return nil, err
	}
	defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
	var result []api.DeviceMessage
	var displayName sql.NullString
	for rows.Next() {
		dk := api.DeviceMessage{
			Type: api.TypeDeviceKeyUpdate,
			DeviceKeys: &api.DeviceKeys{
				UserID: userID,
			},
		}
		if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
			return nil, err
		}
		if displayName.Valid {
			dk.DisplayName = displayName.String
		}
		// include the key if we want all keys (no device IDs given) or it was asked for
		if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
			result = append(result, dk)
		}
	}
	return result, rows.Err()
}

func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
	for i, key := range keys {
		var keyJSONStr string
		var streamID int64
		var displayName sql.NullString
		err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
		if err != nil && err != sql.ErrNoRows {
			return err
		}
		// this will be '' when there is no device
		keys[i].Type = api.TypeDeviceKeyUpdate
		keys[i].KeyJSON = []byte(keyJSONStr)
		keys[i].StreamID = streamID
		if displayName.Valid {
			keys[i].DisplayName = displayName.String
		}
	}
	return nil
}

func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
	// nullable if there are no results
	var nullStream sql.NullInt64
	err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
	if err == sql.ErrNoRows {
		err = nil
	}
	if nullStream.Valid {
		streamID = nullStream.Int64
	}
	return
}

func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
	iStreamIDs := make([]interface{}, len(streamIDs)+1)
	iStreamIDs[0] = userID
	for i := range streamIDs {
		iStreamIDs[i+1] = streamIDs[i]
	}
	query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
	// nullable if there are no results
	var count sql.NullInt64
	err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
	if err != nil {
		return 0, err
	}
	if count.Valid {
		return int(count.Int64), nil
	}
	return 0, nil
}

func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
	for _, key := range keys {
		now := time.Now().Unix()
		_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
			ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
		)
		if err != nil {
			return err
		}
	}
	return nil
}
@@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
	if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
		return nil, err
	}
-	return &api.Device{
+	dev := &api.Device{
		ID:          id,
		UserID:      userutil.MakeUserID(localpart, serverName),
		AccessToken: accessToken,
@@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
		LastSeenTS:  createdTimeMS,
		LastSeenIP:  ipAddr,
		UserAgent:   userAgent,
-	}, nil
+	}
+	if displayName != nil {
+		dev.DisplayName = *displayName
+	}
+	return dev, nil
}

func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
	if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
		return nil, err
	}
-	return &api.Device{
+	dev := &api.Device{
		ID:          id,
		UserID:      userutil.MakeUserID(localpart, serverName),
		AccessToken: accessToken,
@@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
		LastSeenTS:  createdTimeMS,
		LastSeenIP:  ipAddr,
		UserAgent:   userAgent,
-	}, nil
+	}
+	if displayName != nil {
+		dev.DisplayName = *displayName
+	}
+	return dev, nil
}

func (s *devicesStatements) DeleteDevice(
@@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
	if err != nil {
		return err
	}
+	defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
	stmt := sqlutil.TxStmt(txn, prep)
	params := make([]interface{}, len(devices)+2)
	params[0] = localpart
@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
|
|||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
const selectBackupKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
|
@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
|||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
{&s.selectKeysStmt, selectKeysSQL},
|
||||
{&s.selectKeysStmt, selectBackupKeysSQL},
|
||||
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
|
||||
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
|
||||
}.Prepare(db)
|
||||
|
|
125
userapi/storage/sqlite3/key_changes_table.go
Normal file
125
userapi/storage/sqlite3/key_changes_table.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" RETURNING change_id"
|
||||
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||
s := &keyChangesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
if err = executeMigration(context.Background(), db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
|
||||
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func executeMigration(ctx context.Context, db *sql.DB) error {
|
||||
// TODO: Remove when we are sure we are not having goose artefacts in the db
|
||||
// This forces an error, which indicates the migration is already applied, since the
|
||||
// column partition was removed from the table
|
||||
migrationName := "keyserver: refactor key changes"
|
||||
|
||||
var cName string
|
||||
err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
|
||||
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
|
||||
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: migrationName,
|
||||
Up: deltas.UpRefactorKeyChanges,
|
||||
})
|
||||
return m.Up(ctx)
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var offset int64
|
||||
if err := rows.Scan(&userID, &offset); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if offset > latestOffset {
|
||||
latestOffset = offset
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return
|
||||
}
|
208
userapi/storage/sqlite3/one_time_keys_table.go
Normal file
208
userapi/storage/sqlite3/one_time_keys_table.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertKeysSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $6"
|
||||
|
||||
const selectOneTimeKeysSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
const deleteOneTimeKeysSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
deleteOneTimeKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertKeysSQL},
|
||||
{&s.selectKeysStmt, selectOneTimeKeysSQL},
|
||||
{&s.selectKeysCountStmt, selectKeysCountSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
|
||||
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
|
||||
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
wantSet[ka] = true
|
||||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: deviceID,
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
|
||||
) (*api.OneTimeKeysCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: keys.DeviceID,
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
145
userapi/storage/sqlite3/stale_device_lists.go
Normal file
145
userapi/storage/sqlite3/stale_device_lists.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var staleDeviceListsSchema = `
|
||||
-- Stores whether a user's device lists are stale or not.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
user_id TEXT PRIMARY KEY NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
is_stale BOOLEAN NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
`
|
||||
|
||||
const upsertStaleDeviceListSQL = "" +
|
||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id)" +
|
||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
db *sql.DB
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
|
||||
}
|
||||
|
||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(staleDeviceListsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
// we only query for 1 domain or all domains so optimise for those use cases
|
||||
if len(domains) == 0 {
|
||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rowsToUserIDs(ctx, rows)
|
||||
}
|
||||
var result []string
|
||||
for _, domain := range domains {
|
||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userIDs...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
|
||||
stmt, err := s.db.Prepare(qry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
|
||||
params := make([]any, len(userIDs))
|
||||
for i := range userIDs {
|
||||
params[i] = userIDs[i]
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
err = stmt.QueryRowContext(ctx,
|
||||
1, 2, 3, 4,
|
||||
|
@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
err = stmt.QueryRowContext(ctx,
|
||||
1, 2, 3,
|
||||
|
@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
registeredAfter := time.Now().AddDate(0, 0, -30)
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
// NewUserDatabase creates a new accounts and profiles database
|
||||
func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otk, err := NewSqliteOneTimeKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kc, err := NewSqliteKeyChangesTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csk, err := NewSqliteCrossSigningKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
css, err := NewSqliteCrossSigningSigsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
Writer: writer,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -29,15 +29,36 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
)
|
||||
|
||||
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
func NewUserDatabase(
|
||||
base *base.BaseDendrite,
|
||||
dbProperties *config.DatabaseOptions,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
bcryptCost int,
|
||||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
||||
// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
|
||||
// and sets postgres connection parameters.
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewKeyDatabase(base, dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewKeyDatabase(base, dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,12 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -29,14 +32,14 @@ var (
|
|||
ctx = context.Background()
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||
if err != nil {
|
||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||
t.Fatalf("NewUserDatabase returned %s", err)
|
||||
}
|
||||
return db, func() {
|
||||
close()
|
||||
|
@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
|
|||
// Tests storing and getting account data
|
||||
func Test_AccountData(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
|
|||
// Tests the creation of accounts
|
||||
func Test_Accounts(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
|
|||
accessToken := util.RandomString(16)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
|
@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantAuthData := json.RawMessage("my auth data")
|
||||
|
@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
func Test_LoginToken(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create a new token
|
||||
|
@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
|
|||
token := util.RandomString(24)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
|
||||
|
@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
|
@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
appID := util.RandomString(8)
|
||||
|
@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
|
@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
// generate some dummy notifications
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
|
|||
assert.Equal(t, int64(0), total)
|
||||
})
|
||||
}
|
||||
|
||||
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new database: %v", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
|
||||
func MustNotError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("operation failed: %s", err)
|
||||
}
|
||||
|
||||
func TestKeyChanges(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
if deviceChangeIDA == deviceChangeIDB {
|
||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var dbLock sync.Mutex
|
||||
var deviceArray = []string{"AAA", "another_device"}
|
||||
|
||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||
// and that they are returned correctly when querying for device keys.
|
||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||
var err error
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
dbLock.Lock()
|
||||
defer dbLock.Unlock()
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int64{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
if len(msgs) != len(wantStreamIDs) {
|
||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -32,10 +32,10 @@ func NewUserAPIDatabase(
|
|||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (Database, error) {
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
@ -20,10 +20,10 @@ import (
|
|||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
)
|
||||
|
||||
|
@ -145,3 +145,47 @@ const (
|
|||
// uint32.
|
||||
AllNotifications NotificationFilter = (1 << 31) - 1
|
||||
)
|
||||
|
||||
type OneTimeKeys interface {
|
||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||
// Returns an empty map if the key does not exist.
|
||||
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
|
||||
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
}
|
||||
|
||||
type StaleDeviceLists interface {
|
||||
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
||||
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
|
||||
}
|
||||
|
||||
type CrossSigningKeys interface {
|
||||
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
|
||||
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
|
||||
}
|
||||
|
||||
type CrossSigningSigs interface {
|
||||
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
|
||||
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
|
||||
}
|
||||
|
|
94
userapi/storage/tables/stale_device_lists_test.go
Normal file
94
userapi/storage/tables/stale_device_lists_test.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %s", err)
|
||||
}
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStaleDeviceLists(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := "@charlie:localhost"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustCreateTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
|
||||
// Query one server
|
||||
wantStaleUsers := []string{alice.ID, bob.ID}
|
||||
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Query all servers
|
||||
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Delete stale devices
|
||||
deleteUsers := []string{alice.ID, bob.ID}
|
||||
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
|
||||
t.Fatalf("failed to delete stale device lists: %s", err)
|
||||
}
|
||||
|
||||
// Verify we don't get anything back after deleting
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
|
||||
if gotCount := len(gotStaleUsers); gotCount > 0 {
|
||||
t.Fatalf("expected no stale users, got %d", gotCount)
|
||||
}
|
||||
})
|
||||
}
|
50
userapi/types/storage.go
Normal file
50
userapi/types/storage.go
Normal file
|
@ -0,0 +1,50 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"math"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const (
|
||||
// OffsetNewest tells e.g. the database to get the most current data
|
||||
OffsetNewest int64 = math.MaxInt64
|
||||
// OffsetOldest tells e.g. the database to get the oldest data
|
||||
OffsetOldest int64 = 0
|
||||
)
|
||||
|
||||
// KeyTypePurposeToInt maps a purpose to an integer, which is used in the
|
||||
// database to reduce the amount of space taken up by this column.
|
||||
var KeyTypePurposeToInt = map[gomatrixserverlib.CrossSigningKeyPurpose]int16{
|
||||
gomatrixserverlib.CrossSigningKeyPurposeMaster: 1,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeSelfSigning: 2,
|
||||
gomatrixserverlib.CrossSigningKeyPurposeUserSigning: 3,
|
||||
}
|
||||
|
||||
// KeyTypeIntToPurpose maps an integer to a purpose, which is used in the
|
||||
// database to reduce the amount of space taken up by this column.
|
||||
var KeyTypeIntToPurpose = map[int16]gomatrixserverlib.CrossSigningKeyPurpose{
|
||||
1: gomatrixserverlib.CrossSigningKeyPurposeMaster,
|
||||
2: gomatrixserverlib.CrossSigningKeyPurposeSelfSigning,
|
||||
3: gomatrixserverlib.CrossSigningKeyPurposeUserSigning,
|
||||
}
|
||||
|
||||
// Map of purpose -> public key
|
||||
type CrossSigningKeyMap map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.Base64Bytes
|
||||
|
||||
// Map of user ID -> key ID -> signature
|
||||
type CrossSigningSigMap map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes
|
|
@ -17,13 +17,11 @@ package userapi
|
|||
import (
|
||||
"time"
|
||||
|
||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/pushgateway"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/consumers"
|
||||
|
@ -33,16 +31,20 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/util"
|
||||
)
|
||||
|
||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||
// NewInternalAPI returns a concrete implementation of the internal API. Callers
|
||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
base *base.BaseDendrite, cfg *config.UserAPI,
|
||||
appServices []config.ApplicationService, keyAPI keyapi.UserKeyAPI,
|
||||
rsAPI rsapi.UserRoomserverAPI, pgClient pushgateway.Client,
|
||||
) api.UserInternalAPI {
|
||||
base *base.BaseDendrite,
|
||||
rsAPI rsapi.UserRoomserverAPI,
|
||||
fedClient fedsenderapi.KeyserverFederationAPI,
|
||||
) *internal.UserInternalAPI {
|
||||
cfg := &base.Cfg.UserAPI
|
||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||
appServices := base.Cfg.Derived.ApplicationServices
|
||||
|
||||
db, err := storage.NewUserAPIDatabase(
|
||||
pgClient := base.PushGatewayHTTPClient()
|
||||
|
||||
db, err := storage.NewUserDatabase(
|
||||
base,
|
||||
&cfg.AccountDatabase,
|
||||
cfg.Matrix.ServerName,
|
||||
|
@ -55,6 +57,11 @@ func NewInternalAPI(
|
|||
logrus.WithError(err).Panicf("failed to connect to accounts db")
|
||||
}
|
||||
|
||||
keyDB, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to key db")
|
||||
}
|
||||
|
||||
syncProducer := producers.NewSyncAPI(
|
||||
db, js,
|
||||
// TODO: user API should handle syncs for account data. Right now,
|
||||
|
@ -64,17 +71,50 @@ func NewInternalAPI(
|
|||
cfg.Matrix.JetStream.Prefixed(jetstream.OutputClientData),
|
||||
cfg.Matrix.JetStream.Prefixed(jetstream.OutputNotificationData),
|
||||
)
|
||||
keyChangeProducer := &producers.KeyChange{
|
||||
Topic: cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent),
|
||||
JetStream: js,
|
||||
DB: keyDB,
|
||||
}
|
||||
|
||||
userAPI := &internal.UserInternalAPI{
|
||||
DB: db,
|
||||
KeyDatabase: keyDB,
|
||||
SyncProducer: syncProducer,
|
||||
KeyChangeProducer: keyChangeProducer,
|
||||
Config: cfg,
|
||||
AppServices: appServices,
|
||||
KeyAPI: keyAPI,
|
||||
RSAPI: rsAPI,
|
||||
DisableTLSValidation: cfg.PushGatewayDisableTLSValidation,
|
||||
PgClient: pgClient,
|
||||
Cfg: cfg,
|
||||
FedClient: fedClient,
|
||||
}
|
||||
|
||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, keyDB, userAPI, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||
userAPI.Updater = updater
|
||||
// Remove users which we don't share a room with anymore
|
||||
if err := updater.CleanUp(); err != nil {
|
||||
logrus.WithError(err).Error("failed to cleanup stale device lists")
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := updater.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start device list updater")
|
||||
}
|
||||
}()
|
||||
|
||||
dlConsumer := consumers.NewDeviceListUpdateConsumer(
|
||||
base.ProcessContext, cfg, js, updater,
|
||||
)
|
||||
if err := dlConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start device list consumer")
|
||||
}
|
||||
|
||||
sigConsumer := consumers.NewSigningKeyUpdateConsumer(
|
||||
base.ProcessContext, cfg, js, userAPI,
|
||||
)
|
||||
if err := sigConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start signing key consumer")
|
||||
}
|
||||
|
||||
receiptConsumer := consumers.NewOutputReceiptEventConsumer(
|
||||
|
|
|
@ -21,7 +21,10 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/producers"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/nats-io/nats.go"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
@ -38,32 +41,55 @@ const (
|
|||
|
||||
type apiTestOpts struct {
|
||||
loginTokenLifetime time.Duration
|
||||
serverName string
|
||||
}
|
||||
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.Database, func()) {
|
||||
type dummyProducer struct{}
|
||||
|
||||
func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
|
||||
return &nats.PubAck{}, nil
|
||||
}
|
||||
|
||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
|
||||
if opts.loginTokenLifetime == 0 {
|
||||
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||
}
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
accountDB, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
sName := serverName
|
||||
if opts.serverName != "" {
|
||||
sName = gomatrixserverlib.ServerName(opts.serverName)
|
||||
}
|
||||
accountDB, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
}, sName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create account DB: %s", err)
|
||||
}
|
||||
|
||||
keyDB, err := storage.NewKeyDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create key DB: %s", err)
|
||||
}
|
||||
|
||||
cfg := &config.UserAPI{
|
||||
Matrix: &config.Global{
|
||||
SigningIdentity: gomatrixserverlib.SigningIdentity{
|
||||
ServerName: serverName,
|
||||
ServerName: sName,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
|
||||
keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
|
||||
return &internal.UserInternalAPI{
|
||||
DB: accountDB,
|
||||
Config: cfg,
|
||||
DB: accountDB,
|
||||
KeyDatabase: keyDB,
|
||||
Config: cfg,
|
||||
SyncProducer: syncProducer,
|
||||
KeyChangeProducer: keyChangeProducer,
|
||||
}, accountDB, func() {
|
||||
close()
|
||||
baseclose()
|
||||
|
@ -332,3 +358,292 @@ func TestQueryAccountByLocalpart(t *testing.T) {
|
|||
testCases(t, intAPI)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccountData(t *testing.T) {
	ctx := context.Background()
	alice := test.NewUser(t)

	testCases := []struct {
		name      string
		inputData *api.InputAccountDataRequest
		wantErr   bool
	}{
		{
			name:      "not a local user",
			inputData: &api.InputAccountDataRequest{UserID: "@notlocal:example.com"},
			wantErr:   true,
		},
		{
			name:      "local user missing datatype",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID},
			wantErr:   true,
		},
		{
			name:      "missing json",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: nil},
			wantErr:   true,
		},
		{
			name:      "with json",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}")},
		},
		{
			name:      "room data",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.push_rules", AccountData: []byte("{}"), RoomID: "!dummy:test"},
		},
		{
			name:      "ignored users",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.ignored_user_list", AccountData: []byte("{}")},
		},
		{
			name:      "m.fully_read",
			inputData: &api.InputAccountDataRequest{UserID: alice.ID, DataType: "m.fully_read", AccountData: []byte("{}")},
		},
	}

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
		defer close()

		for _, tc := range testCases {
			t.Run(tc.name, func(t *testing.T) {
				res := api.InputAccountDataResponse{}
				err := intAPI.InputAccountData(ctx, tc.inputData, &res)
				if tc.wantErr && err == nil {
					t.Fatalf("expected an error, but got none")
				}
				if !tc.wantErr && err != nil {
					t.Fatalf("expected no error, but got: %s", err)
				}

				// query the data again and compare
				queryRes := api.QueryAccountDataResponse{}
				queryReq := api.QueryAccountDataRequest{
					UserID:   tc.inputData.UserID,
					DataType: tc.inputData.DataType,
					RoomID:   tc.inputData.RoomID,
				}
				err = intAPI.QueryAccountData(ctx, &queryReq, &queryRes)
				if err != nil && !tc.wantErr {
					t.Fatal(err)
				}
				// verify global data
				if tc.inputData.RoomID == "" {
					if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.GlobalAccountData[tc.inputData.DataType]) {
						t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.GlobalAccountData[tc.inputData.DataType]))
					}
				} else {
					// verify room data
					if !reflect.DeepEqual(tc.inputData.AccountData, queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]) {
						t.Fatalf("expected accountdata to be %s, got %s", string(tc.inputData.AccountData), string(queryRes.RoomAccountData[tc.inputData.RoomID][tc.inputData.DataType]))
					}
				}
			})
		}
	})
}

func TestDevices(t *testing.T) {
	ctx := context.Background()

	dupeAccessToken := util.RandomString(8)

	displayName := "testing"

	creationTests := []struct {
		name         string
		inputData    *api.PerformDeviceCreationRequest
		wantErr      bool
		wantNewDevID bool
	}{
		{
			name:      "not a local user",
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", ServerName: "notlocal"},
			wantErr:   true,
		},
		{
			name:      "implicit local user",
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test1", AccessToken: util.RandomString(8), NoDeviceListUpdate: true, DeviceDisplayName: &displayName},
		},
		{
			name:      "explicit local user",
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test2", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
		},
		{
			name:      "dupe token - ok",
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
		},
		{
			name:      "dupe token - not ok",
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: dupeAccessToken, NoDeviceListUpdate: true},
			wantErr:   true,
		},
		{
			name:      "test3 second device", // used to test deletion later
			inputData: &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
		},
		{
			name:         "test3 third device", // used to test deletion later
			wantNewDevID: true,
			inputData:    &api.PerformDeviceCreationRequest{Localpart: "test3", ServerName: "test", AccessToken: util.RandomString(8), NoDeviceListUpdate: true},
		},
	}

	deletionTests := []struct {
		name        string
		inputData   *api.PerformDeviceDeletionRequest
		wantErr     bool
		wantDevices int
	}{
		{
			name:      "deletion - not a local user",
			inputData: &api.PerformDeviceDeletionRequest{UserID: "@test:notlocalhost"},
			wantErr:   true,
		},
		{
			name:        "deleting not existing devices should not error",
			inputData:   &api.PerformDeviceDeletionRequest{UserID: "@test1:test", DeviceIDs: []string{"iDontExist"}},
			wantDevices: 1,
		},
		{
			name:        "delete all devices",
			inputData:   &api.PerformDeviceDeletionRequest{UserID: "@test1:test"},
			wantDevices: 0,
		},
		{
			name:        "delete all devices",
			inputData:   &api.PerformDeviceDeletionRequest{UserID: "@test3:test"},
			wantDevices: 0,
		},
	}

	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
		defer close()

		for _, tc := range creationTests {
			t.Run(tc.name, func(t *testing.T) {
				res := api.PerformDeviceCreationResponse{}
				deviceID := util.RandomString(8)
				tc.inputData.DeviceID = &deviceID
				if tc.wantNewDevID {
					tc.inputData.DeviceID = nil
				}
				err := intAPI.PerformDeviceCreation(ctx, tc.inputData, &res)
				if tc.wantErr && err == nil {
					t.Fatalf("expected an error, but got none")
				}
				if !tc.wantErr && err != nil {
					t.Fatalf("expected no error, but got: %s", err)
				}
				if !res.DeviceCreated {
					return
				}

				queryDevicesRes := api.QueryDevicesResponse{}
				queryDevicesReq := api.QueryDevicesRequest{UserID: res.Device.UserID}
				if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
					t.Fatal(err)
				}
				// We only want to verify one device
				if len(queryDevicesRes.Devices) > 1 {
					return
				}
				res.Device.AccessToken = ""

				// At this point, there should only be one device
				if !reflect.DeepEqual(*res.Device, queryDevicesRes.Devices[0]) {
					t.Fatalf("expected device to be\n%#v, got \n%#v", *res.Device, queryDevicesRes.Devices[0])
				}

				newDisplayName := "new name"
				if tc.inputData.DeviceDisplayName == nil {
					updateRes := api.PerformDeviceUpdateResponse{}
					updateReq := api.PerformDeviceUpdateRequest{
						RequestingUserID: fmt.Sprintf("@%s:%s", tc.inputData.Localpart, "test"),
						DeviceID:         deviceID,
						DisplayName:      &newDisplayName,
					}

					if err = intAPI.PerformDeviceUpdate(ctx, &updateReq, &updateRes); err != nil {
						t.Fatal(err)
					}
				}

				queryDeviceInfosRes := api.QueryDeviceInfosResponse{}
				queryDeviceInfosReq := api.QueryDeviceInfosRequest{DeviceIDs: []string{*tc.inputData.DeviceID}}
				if err = intAPI.QueryDeviceInfos(ctx, &queryDeviceInfosReq, &queryDeviceInfosRes); err != nil {
					t.Fatal(err)
				}
				gotDisplayName := queryDeviceInfosRes.DeviceInfo[*tc.inputData.DeviceID].DisplayName
				if tc.inputData.DeviceDisplayName != nil {
					wantDisplayName := *tc.inputData.DeviceDisplayName
					if wantDisplayName != gotDisplayName {
						t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
					}
				} else {
					wantDisplayName := newDisplayName
					if wantDisplayName != gotDisplayName {
						t.Fatalf("expected displayName to be %s, got %s", wantDisplayName, gotDisplayName)
					}
				}
			})
		}

		for _, tc := range deletionTests {
			t.Run(tc.name, func(t *testing.T) {
				delRes := api.PerformDeviceDeletionResponse{}
				err := intAPI.PerformDeviceDeletion(ctx, tc.inputData, &delRes)
				if tc.wantErr && err == nil {
					t.Fatalf("expected an error, but got none")
				}
				if !tc.wantErr && err != nil {
					t.Fatalf("expected no error, but got: %s", err)
				}
				if tc.wantErr {
					return
				}

				queryDevicesRes := api.QueryDevicesResponse{}
				queryDevicesReq := api.QueryDevicesRequest{UserID: tc.inputData.UserID}
				if err = intAPI.QueryDevices(ctx, &queryDevicesReq, &queryDevicesRes); err != nil {
					t.Fatal(err)
				}

				if len(queryDevicesRes.Devices) != tc.wantDevices {
					t.Fatalf("expected %d devices, got %d", tc.wantDevices, len(queryDevicesRes.Devices))
				}
			})
		}
	})
}

// Tests that the session ID of a device is not reused when reusing the same device ID.
func TestDeviceIDReuse(t *testing.T) {
	ctx := context.Background()
	test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
		intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
		defer close()

		res := api.PerformDeviceCreationResponse{}
		// create a first device
		deviceID := util.RandomString(8)
		req := api.PerformDeviceCreationRequest{Localpart: "alice", ServerName: "test", DeviceID: &deviceID, NoDeviceListUpdate: true}
		err := intAPI.PerformDeviceCreation(ctx, &req, &res)
		if err != nil {
			t.Fatal(err)
		}

		// Do the same request again, we expect a different sessionID
		res2 := api.PerformDeviceCreationResponse{}
		err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
		if err != nil {
			t.Fatalf("expected no error, but got: %v", err)
		}

		if res2.Device.SessionID == res.Device.SessionID {
			t.Fatalf("expected a different session ID, but they are the same")
		}
	})
}

@@ -19,7 +19,7 @@ type PusherDevice struct {
}

// GetPushDevices pushes to the configured devices of a local user.
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.Database) ([]*PusherDevice, error) {
func GetPushDevices(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, tweaks map[string]interface{}, db storage.UserDatabase) ([]*PusherDevice, error) {
	pushers, err := db.GetPushers(ctx, localpart, serverName)
	if err != nil {
		return nil, fmt.Errorf("db.GetPushers: %w", err)
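
// Callers that previously passed a storage.Database now pass the merged
// storage.UserDatabase. A minimal sketch of an updated call site (variable
// names here are illustrative assumptions, not taken from the diff):
//
//	devices, err := GetPushDevices(ctx, "alice", "example.com", nil, userDB)
//	if err != nil {
//		return fmt.Errorf("GetPushDevices: %w", err)
//	}
//	for _, dev := range devices {
//		_ = dev // each entry describes one configured pusher target for the user
//	}
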
@@ -13,11 +13,11 @@ import (
)

// NotifyUserCountsAsync sends notifications to a local user's
// notification destinations. Database lookups run synchronously, but
// notification destinations. UserDatabase lookups run synchronously, but
// a single goroutine is started when talking to the Push
// gateways. There is no way to know when the background goroutine has
// finished.
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.Database) error {
func NotifyUserCountsAsync(ctx context.Context, pgClient pushgateway.Client, localpart string, serverName gomatrixserverlib.ServerName, db storage.UserDatabase) error {
	pusherDevices, err := GetPushDevices(ctx, localpart, serverName, nil, db)
	if err != nil {
		return err
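
// A hedged usage sketch for the updated signature (pgClient and userDB are
// assumed to exist in the caller; they are not part of this diff):
//
//	if err := NotifyUserCountsAsync(ctx, pgClient, "alice", "example.com", userDB); err != nil {
//		return err
//	}
//	// Push-gateway requests run on a background goroutine, so a nil error only
//	// means the database lookups succeeded.
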
@@ -79,7 +79,7 @@ func TestNotifyUserCountsAsync(t *testing.T) {
		defer close()
		base, _, _ := testrig.Base(nil)
		defer base.Close()
		db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
		db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
			ConnectionString: config.DataSource(connStr),
		}, "test", bcrypt.MinCost, 0, 0, "")
		if err != nil {
@@ -21,7 +21,7 @@ func TestCollect(t *testing.T) {
		b, _, _ := testrig.Base(nil)
		connStr, closeDB := test.PrepareDBConnectionString(t, dbType)
		defer closeDB()
		db, err := storage.NewUserAPIDatabase(b, &config.DatabaseOptions{
		db, err := storage.NewUserDatabase(b, &config.DatabaseOptions{
			ConnectionString: config.DataSource(connStr),
		}, "localhost", bcrypt.MinCost, 1000, 1000, "")
		if err != nil {