Key backups (1/2) : Add E2E session backup metadata tables (#1943)

* Initial key backup paths and userapi API

* Fix unit tests

* Add key backup table

* Glue REST API to database

* Linting

* use writer on sqlite
This commit is contained in:
kegsay 2021-07-27 12:47:32 +01:00 committed by GitHub
parent e3679799ea
commit 32538640db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 712 additions and 0 deletions

View file

@ -54,6 +54,12 @@ type Database interface {
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
// Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error)
}
// Err3PIDInUse is the error returned when trying to save an association involving

View file

@ -0,0 +1,144 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
)
const keyBackupVersionTableSchema = `
CREATE SEQUENCE IF NOT EXISTS account_e2e_room_keys_versions_seq;
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version BIGINT DEFAULT nextval('account_e2e_room_keys_versions_seq'),
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
deleted SMALLINT DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well?
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt
updateKeyBackupAuthDataStmt *sql.Stmt
deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
}
if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil {
return
}
if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil {
return
}
if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil {
return
}
if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil {
return
}
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return
}
return
}
func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return false, fmt.Errorf("invalid version")
}
result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
if err != nil {
return false, err
}
ra, err := result.RowsAffected()
if err != nil {
return false, err
}
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
var versionInt int64
if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
} else {
versionInt, err = strconv.ParseInt(version, 10, 64)
}
if err != nil {
return
}
versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int
var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt)
deleted = deletedInt == 1
authData = json.RawMessage(authDataStr)
return
}

View file

@ -45,6 +45,7 @@ type Database struct {
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackups keyBackupVersionStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
@ -93,6 +94,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
@ -368,3 +372,42 @@ func (d *Database) GetOpenIDTokenAttributes(
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData)
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}

View file

@ -0,0 +1,142 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strconv"
)
const keyBackupVersionTableSchema = `
-- the metadata for each generation of encrypted e2e session backups
CREATE TABLE IF NOT EXISTS account_e2e_room_keys_versions (
user_id TEXT NOT NULL,
-- this means no 2 users will ever have the same version of e2e session backups which strictly
-- isn't necessary, but this is easy to do rather than SELECT MAX(version)+1.
version INTEGER PRIMARY KEY AUTOINCREMENT,
algorithm TEXT NOT NULL,
auth_data TEXT NOT NULL,
deleted INTEGER DEFAULT 0 NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS account_e2e_room_keys_versions_idx ON account_e2e_room_keys_versions(user_id, version);
`
const insertKeyBackupSQL = "" +
"INSERT INTO account_e2e_room_keys_versions(user_id, algorithm, auth_data) VALUES ($1, $2, $3) RETURNING version"
const updateKeyBackupAuthDataSQL = "" + // TODO: do we need to WHERE algorithm = $3 as well?
"UPDATE account_e2e_room_keys_versions SET auth_data = $1 WHERE user_id = $2 AND version = $3"
const deleteKeyBackupSQL = "" +
"UPDATE account_e2e_room_keys_versions SET deleted=1 WHERE user_id = $1 AND version = $2"
const selectKeyBackupSQL = "" +
"SELECT algorithm, auth_data, deleted FROM account_e2e_room_keys_versions WHERE user_id = $1 AND version = $2"
const selectLatestVersionSQL = "" +
"SELECT MAX(version) FROM account_e2e_room_keys_versions WHERE user_id = $1"
type keyBackupVersionStatements struct {
insertKeyBackupStmt *sql.Stmt
updateKeyBackupAuthDataStmt *sql.Stmt
deleteKeyBackupStmt *sql.Stmt
selectKeyBackupStmt *sql.Stmt
selectLatestVersionStmt *sql.Stmt
}
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
_, err = db.Exec(keyBackupVersionTableSchema)
if err != nil {
return
}
if s.insertKeyBackupStmt, err = db.Prepare(insertKeyBackupSQL); err != nil {
return
}
if s.updateKeyBackupAuthDataStmt, err = db.Prepare(updateKeyBackupAuthDataSQL); err != nil {
return
}
if s.deleteKeyBackupStmt, err = db.Prepare(deleteKeyBackupSQL); err != nil {
return
}
if s.selectKeyBackupStmt, err = db.Prepare(selectKeyBackupSQL); err != nil {
return
}
if s.selectLatestVersionStmt, err = db.Prepare(selectLatestVersionSQL); err != nil {
return
}
return
}
func (s *keyBackupVersionStatements) insertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
var versionInt int64
err = txn.Stmt(s.insertKeyBackupStmt).QueryRowContext(ctx, userID, algorithm, string(authData)).Scan(&versionInt)
return strconv.FormatInt(versionInt, 10), err
}
func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return fmt.Errorf("invalid version")
}
_, err = txn.Stmt(s.updateKeyBackupAuthDataStmt).ExecContext(ctx, string(authData), userID, versionInt)
return err
}
func (s *keyBackupVersionStatements) deleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64)
if err != nil {
return false, fmt.Errorf("invalid version")
}
result, err := txn.Stmt(s.deleteKeyBackupStmt).ExecContext(ctx, userID, versionInt)
if err != nil {
return false, err
}
ra, err := result.RowsAffected()
if err != nil {
return false, err
}
return ra == 1, nil
}
func (s *keyBackupVersionStatements) selectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
var versionInt int64
if version == "" {
err = txn.Stmt(s.selectLatestVersionStmt).QueryRowContext(ctx, userID).Scan(&versionInt)
} else {
versionInt, err = strconv.ParseInt(version, 10, 64)
}
if err != nil {
return
}
versionResult = strconv.FormatInt(versionInt, 10)
var deletedInt int
var authDataStr string
err = txn.Stmt(s.selectKeyBackupStmt).QueryRowContext(ctx, userID, versionInt).Scan(&algorithm, &authDataStr, &deletedInt)
deleted = deletedInt == 1
authData = json.RawMessage(authDataStr)
return
}

View file

@ -43,6 +43,7 @@ type Database struct {
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackups keyBackupVersionStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
@ -97,6 +98,9 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
return d, nil
}
@ -406,3 +410,42 @@ func (d *Database) GetOpenIDTokenAttributes(
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
version, err = d.keyBackups.insertKeyBackup(ctx, txn, userID, algorithm, authData)
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.keyBackups.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackups.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, deleted, err = d.keyBackups.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}