Key Backups (3/3) : Implement querying keys and various bugfixes (#1946)

* Add querying device keys

Makes a bunch of sytests pass

* Apparently only the current version supports uploading keys

* Linting
This commit is contained in:
kegsay 2021-07-27 19:29:32 +01:00 committed by GitHub
parent b3754d68fc
commit 32bf14a37c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 362 additions and 101 deletions

View file

@ -61,6 +61,8 @@ type Database interface {
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
`
const insertBackupKeySQL = "" +
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysByRoomIDStmt *sql.Stmt
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
return
}
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
return
}
return
}
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession

View file

@ -67,7 +67,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {

View file

@ -96,13 +96,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
/*
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
} */
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
@ -418,6 +417,37 @@ func (d *Database) GetKeyBackup(
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@ -445,7 +475,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
@ -489,22 +519,3 @@ func (d *Database) UpsertBackupKeys(
})
return
}
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}

View file

@ -35,7 +35,8 @@ CREATE TABLE IF NOT EXISTS account_e2e_room_keys (
is_verified BOOLEAN NOT NULL,
session_data TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_idx ON account_e2e_room_keys(user_id, room_id, session_id, version);
CREATE UNIQUE INDEX IF NOT EXISTS e2e_room_keys_versions_idx ON account_e2e_room_keys(user_id, version);
`
const insertBackupKeySQL = "" +
@ -53,14 +54,23 @@ const selectKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2"
const selectKeysByRoomIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3"
const selectKeysByRoomIDAndSessionIDSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM account_e2e_room_keys " +
"WHERE user_id = $1 AND version = $2 AND room_id = $3 AND session_id = $4"
type keyBackupStatements struct {
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
insertBackupKeyStmt *sql.Stmt
updateBackupKeyStmt *sql.Stmt
countKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysByRoomIDStmt *sql.Stmt
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupTableSchema)
if err != nil {
@ -78,6 +88,12 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
if s.selectKeysStmt, err = db.Prepare(selectKeysSQL); err != nil {
return
}
if s.selectKeysByRoomIDStmt, err = db.Prepare(selectKeysByRoomIDSQL); err != nil {
return
}
if s.selectKeysByRoomIDAndSessionIDStmt, err = db.Prepare(selectKeysByRoomIDAndSessionIDSQL); err != nil {
return
}
return
}
@ -109,11 +125,35 @@ func (s *keyBackupStatements) updateBackupKey(
func (s *keyBackupStatements) selectKeys(
ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)
if err != nil {
return nil, err
}
return unpackKeys(ctx, rows)
}
func unpackKeys(ctx context.Context, rows *sql.Rows) (map[string]map[string]api.KeyBackupSession, error) {
result := make(map[string]map[string]api.KeyBackupSession)
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt.Close failed")
for rows.Next() {
var key api.InternalKeyBackupSession

View file

@ -65,7 +65,6 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt
}
// nolint:unused
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {

View file

@ -100,13 +100,12 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
/*
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
} */
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
@ -459,6 +458,37 @@ func (d *Database) GetKeyBackup(
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
@ -486,7 +516,7 @@ func (d *Database) UpsertBackupKeys(
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if shouldReplaceRoomKey(existingSession, newKey.KeyBackupSession) {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
@ -531,22 +561,3 @@ func (d *Database) UpsertBackupKeys(
})
return
}
// TODO FIXME XXX : This logic really shouldn't live in the storage layer, but I don't know where else is sensible which won't
// create circular import loops
func shouldReplaceRoomKey(existing, uploaded api.KeyBackupSession) bool {
// https://spec.matrix.org/unstable/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
// "if the keys have different values for is_verified, then it will keep the key that has is_verified set to true"
if uploaded.IsVerified && !existing.IsVerified {
return true
}
// "if they have the same values for is_verified, then it will keep the key with a lower first_message_index"
if uploaded.FirstMessageIndex < existing.FirstMessageIndex {
return true
}
// "and finally, is is_verified and first_message_index are equal, then it will keep the key with a lower forwarded_count"
if uploaded.ForwardedCount < existing.ForwardedCount {
return true
}
return false
}