Generate stream IDs for locally uploaded device keys (#1236)

* Breaking: add stream_id to keyserver_device_keys table

* Add tests for stream ID generation

* Fix whitelist
This commit is contained in:
Kegsay 2020-08-03 17:07:06 +01:00 committed by GitHub
parent ffcb6d2ea1
commit fb56bbf0b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 265 additions and 84 deletions

View file

@ -43,6 +43,13 @@ func (k *KeyError) Error() string {
return k.Err return k.Err
} }
// DeviceMessage represents the message produced into Kafka by the key server.
type DeviceMessage struct {
DeviceKeys
// A monotonically increasing number which represents device changes for this user.
StreamID int
}
// DeviceKeys represents a set of device keys for a single device // DeviceKeys represents a set of device keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type DeviceKeys struct { type DeviceKeys struct {
@ -50,10 +57,20 @@ type DeviceKeys struct {
UserID string UserID string
// The device ID of this device // The device ID of this device
DeviceID string DeviceID string
// The device display name
DisplayName string
// The raw device key JSON // The raw device key JSON
KeyJSON []byte KeyJSON []byte
} }
// WithStreamID returns a copy of this device message with the given stream ID
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
return DeviceMessage{
DeviceKeys: *k,
StreamID: streamID,
}
}
// OneTimeKeys represents a set of one-time keys for a single device // OneTimeKeys represents a set of one-time keys for a single device
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload // https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-keys-upload
type OneTimeKeys struct { type OneTimeKeys struct {

View file

@ -61,7 +61,7 @@ func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyC
func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) PerformUploadKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
res.KeyErrors = make(map[string]map[string]*api.KeyError) res.KeyErrors = make(map[string]map[string]*api.KeyError)
a.uploadDeviceKeys(ctx, req, res) a.uploadLocalDeviceKeys(ctx, req, res)
a.uploadOneTimeKeys(ctx, req, res) a.uploadOneTimeKeys(ctx, req, res)
} }
@ -286,18 +286,25 @@ func (a *KeyInternalAPI) queryRemoteKeys(
} }
} }
func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) { func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
var keysToStore []api.DeviceKeys var keysToStore []api.DeviceMessage
// assert that the user ID / device ID are not lying for each key // assert that the user ID / device ID are not lying for each key
for _, key := range req.DeviceKeys { for _, key := range req.DeviceKeys {
_, serverName, err := gomatrixserverlib.SplitID('@', key.UserID)
if err != nil {
continue // ignore invalid users
}
if serverName != a.ThisServer {
continue // ignore remote users
}
if len(key.KeyJSON) == 0 { if len(key.KeyJSON) == 0 {
keysToStore = append(keysToStore, key) keysToStore = append(keysToStore, key.WithStreamID(0))
continue // deleted keys don't need sanity checking continue // deleted keys don't need sanity checking
} }
gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str gotUserID := gjson.GetBytes(key.KeyJSON, "user_id").Str
gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str gotDeviceID := gjson.GetBytes(key.KeyJSON, "device_id").Str
if gotUserID == key.UserID && gotDeviceID == key.DeviceID { if gotUserID == key.UserID && gotDeviceID == key.DeviceID {
keysToStore = append(keysToStore, key) keysToStore = append(keysToStore, key.WithStreamID(0))
continue continue
} }
@ -310,11 +317,13 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
} }
// get existing device keys so we can check for changes // get existing device keys so we can check for changes
existingKeys := make([]api.DeviceKeys, len(keysToStore)) existingKeys := make([]api.DeviceMessage, len(keysToStore))
for i := range keysToStore { for i := range keysToStore {
existingKeys[i] = api.DeviceKeys{ existingKeys[i] = api.DeviceMessage{
DeviceKeys: api.DeviceKeys{
UserID: keysToStore[i].UserID, UserID: keysToStore[i].UserID,
DeviceID: keysToStore[i].DeviceID, DeviceID: keysToStore[i].DeviceID,
},
} }
} }
if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil { if err := a.DB.DeviceKeysJSON(ctx, existingKeys); err != nil {
@ -324,13 +333,14 @@ func (a *KeyInternalAPI) uploadDeviceKeys(ctx context.Context, req *api.PerformU
return return
} }
// store the device keys and emit changes // store the device keys and emit changes
if err := a.DB.StoreDeviceKeys(ctx, keysToStore); err != nil { err := a.DB.StoreDeviceKeys(ctx, keysToStore)
if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: fmt.Sprintf("failed to store device keys: %s", err.Error()), Err: fmt.Sprintf("failed to store device keys: %s", err.Error()),
} }
return return
} }
err := a.emitDeviceKeyChanges(existingKeys, keysToStore) err = a.emitDeviceKeyChanges(existingKeys, keysToStore)
if err != nil { if err != nil {
util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err) util.GetLogger(ctx).Errorf("Failed to emitDeviceKeyChanges: %s", err)
} }
@ -375,9 +385,9 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
} }
func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceKeys) error { func (a *KeyInternalAPI) emitDeviceKeyChanges(existing, new []api.DeviceMessage) error {
// find keys in new that are not in existing // find keys in new that are not in existing
var keysAdded []api.DeviceKeys var keysAdded []api.DeviceMessage
for _, newKey := range new { for _, newKey := range new {
exists := false exists := false
for _, existingKey := range existing { for _, existingKey := range existing {

View file

@ -41,7 +41,7 @@ func (p *KeyChange) DefaultPartition() int32 {
} }
// ProduceKeyChanges creates new change events for each key // ProduceKeyChanges creates new change events for each key
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error { func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
var m sarama.ProducerMessage var m sarama.ProducerMessage

View file

@ -32,17 +32,18 @@ type Database interface {
// OneTimeKeysCount returns a count of all OTKs for this device. // OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` already then it will be replaced. // DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key // StoreDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). // for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys. // Returns an error if there was a problem storing the keys.
StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected. // DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice. // If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -32,28 +31,37 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
-- This means we do not store an unbounded append-only log of device keys, which is not actually
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id) CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" + " ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4" " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
} }
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,38 +81,54 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
} }
return nil return nil
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
now := time.Now().Unix() // nullable if there are no results
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
}
return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
) )
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
})
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID) rows, err := s.selectBatchDeviceKeysStmt.QueryContext(ctx, userID)
if err != nil { if err != nil {
return nil, err return nil, err
@ -114,15 +138,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
} }
var result []api.DeviceKeys var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceKeys var dk api.DeviceMessage
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)

View file

@ -43,15 +43,36 @@ func (d *Database) OneTimeKeysCount(ctx context.Context, userID, deviceID string
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID) return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
} }
func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys) return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
} }
func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (d *Database) StoreDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.InsertDeviceKeys(ctx, keys) // work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return sqlutil.WithTransaction(d.DB, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = int(streamID)
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
} }
func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (d *Database) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs) return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs)
} }

View file

@ -20,7 +20,6 @@ import (
"time" "time"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage/tables" "github.com/matrix-org/dendrite/keyserver/storage/tables"
) )
@ -32,28 +31,33 @@ CREATE TABLE IF NOT EXISTS keyserver_device_keys (
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL, ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL, key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
-- Clobber based on tuple of user/device. -- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id) UNIQUE (user_id, device_id)
); );
` `
const upsertDeviceKeysSQL = "" + const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json)" + "INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id)" +
" VALUES ($1, $2, $3, $4)" + " VALUES ($1, $2, $3, $4, $5)" +
" ON CONFLICT (user_id, device_id)" + " ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4" " DO UPDATE SET key_json = $4, stream_id = $5"
const selectDeviceKeysSQL = "" + const selectDeviceKeysSQL = "" +
"SELECT key_json FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2" "SELECT key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" + const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json FROM keyserver_device_keys WHERE user_id=$1" "SELECT device_id, key_json, stream_id FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct { type deviceKeysStatements struct {
db *sql.DB db *sql.DB
upsertDeviceKeysStmt *sql.Stmt upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt selectBatchDeviceKeysStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
} }
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) { func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
@ -73,10 +77,13 @@ func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil { if s.selectBatchDeviceKeysStmt, err = db.Prepare(selectBatchDeviceKeysSQL); err != nil {
return nil, err return nil, err
} }
if s.selectMaxStreamForUserStmt, err = db.Prepare(selectMaxStreamForUserSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) { func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool) deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs { for _, d := range deviceIDs {
deviceIDMap[d] = true deviceIDMap[d] = true
@ -86,15 +93,17 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return nil, err return nil, err
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceKeys var result []api.DeviceMessage
for rows.Next() { for rows.Next() {
var dk api.DeviceKeys var dk api.DeviceMessage
dk.UserID = userID dk.UserID = userID
var keyJSON string var keyJSON string
if err := rows.Scan(&dk.DeviceID, &keyJSON); err != nil { var streamID int
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID); err != nil {
return nil, err return nil, err
} }
dk.KeyJSON = []byte(keyJSON) dk.KeyJSON = []byte(keyJSON)
dk.StreamID = streamID
// include the key if we want all keys (no device) or it was asked // include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 { if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk) result = append(result, dk)
@ -103,30 +112,43 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
return result, rows.Err() return result, rows.Err()
} }
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys { for i, key := range keys {
var keyJSONStr string var keyJSONStr string
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr) var streamID int
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID)
if err != nil && err != sql.ErrNoRows { if err != nil && err != sql.ErrNoRows {
return err return err
} }
// this will be '' when there is no device // this will be '' when there is no device
keys[i].KeyJSON = []byte(keyJSONStr) keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
} }
return nil return nil
} }
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error { func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
now := time.Now().Unix() // nullable if there are no results
return sqlutil.WithTransaction(s.db, func(txn *sql.Tx) error { var nullStream sql.NullInt32
err = txn.Stmt(s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
}
return
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys { for _, key := range keys {
now := time.Now().Unix()
_, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext( _, err := txn.Stmt(s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID,
) )
if err != nil { if err != nil {
return err return err
} }
} }
return nil return nil
})
} }

View file

@ -6,6 +6,7 @@ import (
"testing" "testing"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api"
) )
var ctx = context.Background() var ctx = context.Background()
@ -77,3 +78,84 @@ func TestKeyChangesUpperLimit(t *testing.T) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
} }
} }
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
{
DeviceKeys: api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
DeviceKeys: api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, []string{"AAA", "another_device"})
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
}

View file

@ -32,9 +32,10 @@ type OneTimeKeys interface {
} }
type DeviceKeys interface { type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceKeys) error SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceMessage, error)
} }
type KeyChanges interface { type KeyChanges interface {

View file

@ -98,7 +98,7 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
defer func() { defer func() {
s.updateOffset(msg) s.updateOffset(msg)
}() }()
var output api.DeviceKeys var output api.DeviceMessage
if err := json.Unmarshal(msg.Value, &output); err != nil { if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream // If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server") log.WithError(err).Error("syncapi: failed to unmarshal key change event from key server")

View file

@ -110,6 +110,7 @@ Rooms a user is invited to appear in an incremental sync
Sync can be polled for updates Sync can be polled for updates
Sync is woken up for leaves Sync is woken up for leaves
Newly left rooms appear in the leave section of incremental sync Newly left rooms appear in the leave section of incremental sync
Rooms can be created with an initial invite list (SYN-205)
We should see our own leave event, even if history_visibility is restricted (SYN-662) We should see our own leave event, even if history_visibility is restricted (SYN-662)
We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462) We should see our own leave event when rejecting an invite, even if history_visibility is restricted (riot-web/3462)
Newly left rooms appear in the leave section of gapped sync Newly left rooms appear in the leave section of gapped sync