mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
This commit is contained in:
parent
bd6f0c14e5
commit
4594233f89
107 changed files with 1730 additions and 1863 deletions
|
@ -90,7 +90,7 @@ type KeyBackup interface {
|
|||
|
||||
type LoginToken interface {
|
||||
// CreateLoginToken generates a token, stores and returns it. The lifetime is
|
||||
// determined by the loginTokenLifetime given to the Database constructor.
|
||||
// determined by the loginTokenLifetime given to the UserDatabase constructor.
|
||||
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
|
||||
|
||||
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
|
||||
|
@ -130,7 +130,7 @@ type Notification interface {
|
|||
DeleteOldNotifications(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Database interface {
|
||||
type UserDatabase interface {
|
||||
Account
|
||||
AccountData
|
||||
Device
|
||||
|
@ -144,6 +144,78 @@ type Database interface {
|
|||
ThreePID
|
||||
}
|
||||
|
||||
type KeyChangeDatabase interface {
|
||||
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||
// `userID` is the the user who has changed their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
}
|
||||
|
||||
type KeyDatabase interface {
|
||||
KeyChangeDatabase
|
||||
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
|
||||
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
|
||||
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
|
||||
// StoreOneTimeKeys persists the given one-time keys.
|
||||
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||
|
||||
// OneTimeKeysCount returns a count of all OTKs for this device.
|
||||
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device).
|
||||
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
|
||||
// Returns an error if there was a problem storing the keys.
|
||||
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
||||
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
|
||||
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
|
||||
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
|
||||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
|
||||
|
||||
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
|
||||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||
|
||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||
// A to offset of types.OffsetNewest means no upper limit.
|
||||
// Returns the offset of the latest key change.
|
||||
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
|
||||
|
||||
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
|
||||
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
|
||||
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
|
||||
|
||||
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
|
||||
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||
|
||||
DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error
|
||||
}
|
||||
|
||||
type Statistics interface {
|
||||
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
|
||||
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)
|
||||
|
|
|
@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
|
||||
// when passing the data to trigger "NOT NULL" constraint
|
||||
var data *json.RawMessage
|
||||
if len(content) > 0 {
|
||||
data = &content
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
key_type SMALLINT NOT NULL,
|
||||
key_data TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key_type)
|
||||
);
|
||||
`
|
||||
|
||||
const selectCrossSigningKeysForUserSQL = "" +
|
||||
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||
" WHERE user_id = $1"
|
||||
|
||||
const upsertCrossSigningKeysForUserSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||
" VALUES($1, $2, $3)" +
|
||||
" ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
|
||||
|
||||
type crossSigningKeysStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningKeysForUserStmt *sql.Stmt
|
||||
upsertCrossSigningKeysForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
|
||||
s := &crossSigningKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string,
|
||||
) (r types.CrossSigningKeyMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
|
||||
r = types.CrossSigningKeyMap{}
|
||||
for rows.Next() {
|
||||
var keyTypeInt int16
|
||||
var keyData gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
|
||||
}
|
||||
r[keyType] = keyData
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||
}
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningSigsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`
|
||||
|
||||
const selectCrossSigningSigsForTargetSQL = "" +
|
||||
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
|
||||
|
||||
const upsertCrossSigningSigsForTargetSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||
" VALUES($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
|
||||
|
||||
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||
|
||||
type crossSigningSigsStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
upsertCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
deleteCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
|
||||
s := &crossSigningSigsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningSigsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "keyserver: cross signing signature indexes",
|
||||
Up: deltas.UpFixCrossSigningSignatureIndexes,
|
||||
})
|
||||
if err = m.Up(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
|
||||
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
|
||||
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) (r types.CrossSigningSigMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
|
||||
r = types.CrossSigningSigMap{}
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
var signature gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := r[userID]; !ok {
|
||||
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
r[userID][keyID] = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
signature gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
|
||||
// even are entries in this table to know if we can query for log_offset. Without the count then
|
||||
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
|
||||
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
|
||||
var count int
|
||||
_ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
|
||||
if count > 0 {
|
||||
var maxOffset int64
|
||||
_ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
|
||||
if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
|
||||
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
|
||||
|
||||
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
213
userapi/storage/postgres/device_keys_table.go
Normal file
213
userapi/storage/postgres/device_keys_table.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
-- Stores device keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
|
||||
-- This means we do not store an unbounded append-only log of device keys, which is not actually
|
||||
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||
stream_id BIGINT NOT NULL,
|
||||
display_name TEXT,
|
||||
-- Clobber based on tuple of user/device.
|
||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertDeviceKeysSQL = "" +
|
||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
const selectDeviceKeysSQL = "" +
|
||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
const deleteDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(deviceKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
|
||||
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
|
||||
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
|
||||
{&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
|
||||
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
|
||||
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].Type = api.TypeDeviceKeyUpdate
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
var result []api.DeviceMessage
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
dk := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: userID,
|
||||
},
|
||||
}
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
result = append(result, dk)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
|
|||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
dev := &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
|
|||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
if displayName != nil {
|
||||
dev.DisplayName = *displayName
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||
|
|
|
@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
|
|||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
const selectBackupKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
|
@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
|||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
{&s.selectKeysStmt, selectKeysSQL},
|
||||
{&s.selectKeysStmt, selectBackupKeysSQL},
|
||||
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
|
||||
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
|
||||
}.Prepare(db)
|
||||
|
|
127
userapi/storage/postgres/key_changes_table.go
Normal file
127
userapi/storage/postgres/key_changes_table.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
|
||||
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
|
||||
" RETURNING change_id"
|
||||
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||
s := &keyChangesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
if err = executeMigration(context.Background(), db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
|
||||
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func executeMigration(ctx context.Context, db *sql.DB) error {
|
||||
// TODO: Remove when we are sure we are not having goose artefacts in the db
|
||||
// This forces an error, which indicates the migration is already applied, since the
|
||||
// column partition was removed from the table
|
||||
migrationName := "keyserver: refactor key changes"
|
||||
|
||||
var cName string
|
||||
err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
|
||||
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
|
||||
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: migrationName,
|
||||
Up: deltas.UpRefactorKeyChanges,
|
||||
})
|
||||
|
||||
return m.Up(ctx)
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var offset int64
|
||||
if err := rows.Scan(&userID, &offset); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if offset > latestOffset {
|
||||
latestOffset = offset
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return
|
||||
}
|
194
userapi/storage/postgres/one_time_keys_table.go
Normal file
194
userapi/storage/postgres/one_time_keys_table.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertKeysSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
|
||||
" DO UPDATE SET key_json = $6"
|
||||
|
||||
const selectOneTimeKeysSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
|
||||
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
const deleteOneTimeKeysSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
deleteOneTimeKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertKeysSQL},
|
||||
{&s.selectKeysStmt, selectOneTimeKeysSQL},
|
||||
{&s.selectKeysCountStmt, selectKeysCountSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
|
||||
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
|
||||
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
algorithmWithID string
|
||||
keyJSONStr string
|
||||
)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: deviceID,
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: keys.DeviceID,
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
131
userapi/storage/postgres/stale_device_lists.go
Normal file
131
userapi/storage/postgres/stale_device_lists.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var staleDeviceListsSchema = `
|
||||
-- Stores whether a user's device lists are stale or not.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
user_id TEXT PRIMARY KEY NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
is_stale BOOLEAN NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
`
|
||||
|
||||
const upsertStaleDeviceListSQL = "" +
|
||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id)" +
|
||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
deleteStaleDeviceListsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{}
|
||||
_, err := db.Exec(staleDeviceListsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
// we only query for 1 domain or all domains so optimise for those use cases
|
||||
if len(domains) == 0 {
|
||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rowsToUserIDs(ctx, rows)
|
||||
}
|
||||
var result []string
|
||||
for _, domain := range domains {
|
||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userIDs...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otk, err := NewPostgresOneTimeKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kc, err := NewPostgresKeyChangesTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdl, err := NewPostgresStaleDeviceListsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csk, err := NewPostgresCrossSigningKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
css, err := NewPostgresCrossSigningSigsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
Writer: writer,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -59,6 +59,17 @@ type Database struct {
|
|||
OpenIDTokenLifetimeMS int64
|
||||
}
|
||||
|
||||
type KeyDatabase struct {
|
||||
OneTimeKeysTable tables.OneTimeKeys
|
||||
DeviceKeysTable tables.DeviceKeys
|
||||
KeyChangesTable tables.KeyChanges
|
||||
StaleDeviceListsTable tables.StaleDeviceLists
|
||||
CrossSigningKeysTable tables.CrossSigningKeys
|
||||
CrossSigningSigsTable tables.CrossSigningSigs
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
}
|
||||
|
||||
const (
|
||||
// The length of generated device IDs
|
||||
deviceIDByteLength = 6
|
||||
|
@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
|
|||
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
|
||||
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count == len(prevIDs), nil
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for _, userID := range clearUserIDs {
|
||||
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
// work out the latest stream IDs for each user
|
||||
userIDToStreamID := make(map[string]int64)
|
||||
for _, k := range keys {
|
||||
userIDToStreamID[k.UserID] = 0
|
||||
}
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for userID := range userIDToStreamID {
|
||||
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userIDToStreamID[userID] = streamID
|
||||
}
|
||||
// set the stream IDs for each key
|
||||
for i := range keys {
|
||||
k := keys[i]
|
||||
userIDToStreamID[k.UserID]++ // start stream from 1
|
||||
k.StreamID = userIDToStreamID[k.UserID]
|
||||
keys[i] = k
|
||||
}
|
||||
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
|
||||
var result []api.OneTimeKeys
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
|
||||
for deviceID, algo := range deviceToAlgo {
|
||||
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if keyJSON != nil {
|
||||
result = append(result, api.OneTimeKeys{
|
||||
UserID: userID,
|
||||
DeviceID: deviceID,
|
||||
KeyJSON: keyJSON,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
|
||||
}
|
||||
|
||||
// MarkDeviceListStale sets the stale bit for this user to isStale.
|
||||
func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
|
||||
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
|
||||
// cross-signing signatures relating to that device.
|
||||
func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for _, deviceID := range deviceIDs {
|
||||
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
|
||||
}
|
||||
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
|
||||
}
|
||||
if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
|
||||
return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
|
||||
func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
|
||||
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
|
||||
}
|
||||
results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
|
||||
for purpose, key := range keyMap {
|
||||
keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
|
||||
result := gomatrixserverlib.CrossSigningKey{
|
||||
UserID: userID,
|
||||
Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
|
||||
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
|
||||
keyID: key,
|
||||
},
|
||||
}
|
||||
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
for sigUserID, forSigUserID := range sigMap {
|
||||
if userID != sigUserID {
|
||||
continue
|
||||
}
|
||||
if result.Signatures == nil {
|
||||
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
if _, ok := result.Signatures[sigUserID]; !ok {
|
||||
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
for sigKeyID, sigBytes := range forSigUserID {
|
||||
result.Signatures[sigUserID][sigKeyID] = sigBytes
|
||||
}
|
||||
}
|
||||
results[purpose] = result
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
|
||||
func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
|
||||
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
|
||||
}
|
||||
|
||||
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
|
||||
func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
|
||||
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
|
||||
}
|
||||
|
||||
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
|
||||
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
for keyType, keyData := range keyMap {
|
||||
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
|
||||
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
|
||||
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
|
||||
ctx context.Context,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
signature gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||
return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
|
||||
func (d *KeyDatabase) DeleteStaleDeviceLists(
|
||||
ctx context.Context,
|
||||
userIDs []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
|
||||
})
|
||||
}
|
||||
|
|
101
userapi/storage/sqlite3/cross_signing_keys_table.go
Normal file
101
userapi/storage/sqlite3/cross_signing_keys_table.go
Normal file
|
@ -0,0 +1,101 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
key_type INTEGER NOT NULL,
|
||||
key_data TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key_type)
|
||||
);
|
||||
`
|
||||
|
||||
const selectCrossSigningKeysForUserSQL = "" +
|
||||
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||
" WHERE user_id = $1"
|
||||
|
||||
const upsertCrossSigningKeysForUserSQL = "" +
|
||||
"INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||
" VALUES($1, $2, $3)"
|
||||
|
||||
type crossSigningKeysStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningKeysForUserStmt *sql.Stmt
|
||||
upsertCrossSigningKeysForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
|
||||
s := &crossSigningKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string,
|
||||
) (r types.CrossSigningKeyMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
|
||||
r = types.CrossSigningKeyMap{}
|
||||
for rows.Next() {
|
||||
var keyTypeInt int16
|
||||
var keyData gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
|
||||
}
|
||||
r[keyType] = keyData
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||
}
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
129
userapi/storage/sqlite3/cross_signing_sigs_table.go
Normal file
129
userapi/storage/sqlite3/cross_signing_sigs_table.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningSigsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`
|
||||
|
||||
const selectCrossSigningSigsForTargetSQL = "" +
|
||||
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
|
||||
|
||||
const upsertCrossSigningSigsForTargetSQL = "" +
|
||||
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||
" VALUES($1, $2, $3, $4, $5)"
|
||||
|
||||
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||
|
||||
type crossSigningSigsStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
upsertCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
deleteCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
|
||||
s := &crossSigningSigsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningSigsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "keyserver: cross signing signature indexes",
|
||||
Up: deltas.UpFixCrossSigningSignatureIndexes,
|
||||
})
|
||||
if err = m.Up(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
|
||||
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
|
||||
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) (r types.CrossSigningSigMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
|
||||
r = types.CrossSigningSigMap{}
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
var signature gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := r[userID]; !ok {
|
||||
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
r[userID][keyID] = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
signature gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0.
|
||||
var maxOffset int64
|
||||
var userID string
|
||||
_ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
|
||||
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
// to start counting from maxOffset, insert a row with that value
|
||||
if userID != "" {
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
offset BIGINT NOT NULL,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (partition, offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,71 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
|
||||
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
|
||||
|
||||
DROP TABLE keyserver_cross_signing_sigs;
|
||||
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
|
||||
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
|
||||
|
||||
DROP TABLE keyserver_cross_signing_sigs;
|
||||
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
|
||||
|
||||
DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
213
userapi/storage/sqlite3/device_keys_table.go
Normal file
213
userapi/storage/sqlite3/device_keys_table.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
-- Stores device keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
stream_id BIGINT NOT NULL,
|
||||
display_name TEXT,
|
||||
-- Clobber based on tuple of user/device.
|
||||
UNIQUE (user_id, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertDeviceKeysSQL = "" +
|
||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (user_id, device_id)" +
|
||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
const selectDeviceKeysSQL = "" +
|
||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
|
||||
|
||||
const deleteDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(deviceKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
|
||||
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
|
||||
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
|
||||
// {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
|
||||
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
|
||||
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
var result []api.DeviceMessage
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
dk := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: userID,
|
||||
},
|
||||
}
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
result = append(result, dk)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].Type = api.TypeDeviceKeyUpdate
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
iStreamIDs := make([]interface{}, len(streamIDs)+1)
|
||||
iStreamIDs[0] = userID
|
||||
for i := range streamIDs {
|
||||
iStreamIDs[i+1] = streamIDs[i]
|
||||
}
|
||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt64
|
||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int64), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
|
|||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
dev := &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
|
|||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
if displayName != nil {
|
||||
dev.DisplayName = *displayName
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||
|
@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
|
|||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
dev := &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
|
|||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
if displayName != nil {
|
||||
dev.DisplayName = *displayName
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
|
@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+2)
|
||||
params[0] = localpart
|
||||
|
|
|
@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
|
|||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
const selectBackupKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
|
@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
|||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
{&s.selectKeysStmt, selectKeysSQL},
|
||||
{&s.selectKeysStmt, selectBackupKeysSQL},
|
||||
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
|
||||
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
|
||||
}.Prepare(db)
|
||||
|
|
125
userapi/storage/sqlite3/key_changes_table.go
Normal file
125
userapi/storage/sqlite3/key_changes_table.go
Normal file
|
@ -0,0 +1,125 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" RETURNING change_id"
|
||||
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||
s := &keyChangesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
if err = executeMigration(context.Background(), db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
|
||||
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func executeMigration(ctx context.Context, db *sql.DB) error {
|
||||
// TODO: Remove when we are sure we are not having goose artefacts in the db
|
||||
// This forces an error, which indicates the migration is already applied, since the
|
||||
// column partition was removed from the table
|
||||
migrationName := "keyserver: refactor key changes"
|
||||
|
||||
var cName string
|
||||
err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
|
||||
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
|
||||
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: migrationName,
|
||||
Up: deltas.UpRefactorKeyChanges,
|
||||
})
|
||||
return m.Up(ctx)
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var offset int64
|
||||
if err := rows.Scan(&userID, &offset); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if offset > latestOffset {
|
||||
latestOffset = offset
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return
|
||||
}
|
208
userapi/storage/sqlite3/one_time_keys_table.go
Normal file
208
userapi/storage/sqlite3/one_time_keys_table.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertKeysSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
|
||||
" DO UPDATE SET key_json = $6"
|
||||
|
||||
const selectOneTimeKeysSQL = "" +
|
||||
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
const deleteOneTimeKeysSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
deleteOneTimeKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertKeysSQL},
|
||||
{&s.selectKeysStmt, selectOneTimeKeysSQL},
|
||||
{&s.selectKeysCountStmt, selectKeysCountSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
|
||||
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
|
||||
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
|
||||
for _, ka := range keyIDsWithAlgorithms {
|
||||
wantSet[ka] = true
|
||||
}
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var algorithm string
|
||||
var keyJSONStr string
|
||||
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyIDWithAlgo := algorithm + ":" + keyID
|
||||
if wantSet[keyIDWithAlgo] {
|
||||
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: deviceID,
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(
|
||||
ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
|
||||
) (*api.OneTimeKeysCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: keys.DeviceID,
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if keyJSON == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
145
userapi/storage/sqlite3/stale_device_lists.go
Normal file
145
userapi/storage/sqlite3/stale_device_lists.go
Normal file
|
@ -0,0 +1,145 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var staleDeviceListsSchema = `
|
||||
-- Stores whether a user's device lists are stale or not.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
user_id TEXT PRIMARY KEY NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
is_stale BOOLEAN NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
`
|
||||
|
||||
const upsertStaleDeviceListSQL = "" +
|
||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id)" +
|
||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
db *sql.DB
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
|
||||
}
|
||||
|
||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(staleDeviceListsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
// we only query for 1 domain or all domains so optimise for those use cases
|
||||
if len(domains) == 0 {
|
||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rowsToUserIDs(ctx, rows)
|
||||
}
|
||||
var result []string
|
||||
for _, domain := range domains {
|
||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userIDs...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
|
||||
stmt, err := s.db.Prepare(qry)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
|
||||
params := make([]any, len(userIDs))
|
||||
for i := range userIDs {
|
||||
params[i] = userIDs[i]
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
err = stmt.QueryRowContext(ctx,
|
||||
1, 2, 3, 4,
|
||||
|
@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
|
|||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
err = stmt.QueryRowContext(ctx,
|
||||
1, 2, 3,
|
||||
|
@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
|
||||
stmt := sqlutil.TxStmt(txn, queryStmt)
|
||||
registeredAfter := time.Now().AddDate(0, 0, -30)
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
|
||||
)
|
||||
|
||||
// NewDatabase creates a new accounts and profiles database
|
||||
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
// NewUserDatabase creates a new accounts and profiles database
|
||||
func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otk, err := NewSqliteOneTimeKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewSqliteDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kc, err := NewSqliteKeyChangesTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdl, err := NewSqliteStaleDeviceListsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csk, err := NewSqliteCrossSigningKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
css, err := NewSqliteCrossSigningSigsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
Writer: writer,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -29,15 +29,36 @@ import (
|
|||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
)
|
||||
|
||||
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
|
||||
// and sets postgres connection parameters
|
||||
func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
|
||||
func NewUserDatabase(
|
||||
base *base.BaseDendrite,
|
||||
dbProperties *config.DatabaseOptions,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
bcryptCost int,
|
||||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
||||
// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
|
||||
// and sets postgres connection parameters.
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewKeyDatabase(base, dbProperties)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewKeyDatabase(base, dbProperties)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,9 +4,12 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
@ -29,14 +32,14 @@ var (
|
|||
ctx = context.Background()
|
||||
)
|
||||
|
||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
|
||||
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
|
||||
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
|
||||
if err != nil {
|
||||
t.Fatalf("NewUserAPIDatabase returned %s", err)
|
||||
t.Fatalf("NewUserDatabase returned %s", err)
|
||||
}
|
||||
return db, func() {
|
||||
close()
|
||||
|
@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
|
|||
// Tests storing and getting account data
|
||||
func Test_AccountData(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
|
|||
// Tests the creation of accounts
|
||||
func Test_Accounts(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
alice := test.NewUser(t)
|
||||
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
|
||||
|
@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
|
|||
accessToken := util.RandomString(16)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
|
||||
|
@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
wantAuthData := json.RawMessage("my auth data")
|
||||
|
@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
|
|||
func Test_LoginToken(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create a new token
|
||||
|
@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
|
|||
token := util.RandomString(24)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
|
||||
|
@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
// create account, which also creates a profile
|
||||
|
@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
|
||||
appID := util.RandomString(8)
|
||||
|
@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
threePID := util.RandomString(8)
|
||||
medium := util.RandomString(8)
|
||||
|
@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
|
|||
room := test.NewRoom(t, alice)
|
||||
room2 := test.NewRoom(t, alice)
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateDatabase(t, dbType)
|
||||
db, close := mustCreateUserDatabase(t, dbType)
|
||||
defer close()
|
||||
// generate some dummy notifications
|
||||
for i := 0; i < 10; i++ {
|
||||
|
@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
|
|||
assert.Equal(t, int64(0), total)
|
||||
})
|
||||
}
|
||||
|
||||
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new database: %v", err)
|
||||
}
|
||||
return db, close
|
||||
}
|
||||
|
||||
func MustNotError(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
t.Fatalf("operation failed: %s", err)
|
||||
}
|
||||
|
||||
func TestKeyChanges(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
if deviceChangeIDA == deviceChangeIDB {
|
||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
var dbLock sync.Mutex
|
||||
var deviceArray = []string{"AAA", "another_device"}
|
||||
|
||||
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
|
||||
// and that they are returned correctly when querying for device keys.
|
||||
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
||||
var err error
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, clean := mustCreateKeyDatabase(t, dbType)
|
||||
defer clean()
|
||||
alice := "@alice:TestDeviceKeysStreamIDGeneration"
|
||||
bob := "@bob:TestDeviceKeysStreamIDGeneration"
|
||||
msgs := []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: bob,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 1 as this is a different user
|
||||
},
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "another_device",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v1"}`),
|
||||
},
|
||||
// StreamID: 2 as this is a 2nd device key
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
|
||||
}
|
||||
if msgs[1].StreamID != 1 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
|
||||
}
|
||||
if msgs[2].StreamID != 2 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
|
||||
}
|
||||
|
||||
// updating a device sets the next stream ID for that user
|
||||
msgs = []api.DeviceMessage{
|
||||
{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
DeviceID: "AAA",
|
||||
UserID: alice,
|
||||
KeyJSON: []byte(`{"key":"v2"}`),
|
||||
},
|
||||
// StreamID: 3
|
||||
},
|
||||
}
|
||||
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
|
||||
if msgs[0].StreamID != 3 {
|
||||
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
|
||||
}
|
||||
|
||||
dbLock.Lock()
|
||||
defer dbLock.Unlock()
|
||||
// Querying for device keys returns the latest stream IDs
|
||||
msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int64{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
if len(msgs) != len(wantStreamIDs) {
|
||||
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
|
||||
}
|
||||
for _, m := range msgs {
|
||||
if m.StreamID != wantStreamIDs[m.DeviceID] {
|
||||
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -32,10 +32,10 @@ func NewUserAPIDatabase(
|
|||
openIDTokenLifetimeMS int64,
|
||||
loginTokenLifetime time.Duration,
|
||||
serverNoticesLocalpart string,
|
||||
) (Database, error) {
|
||||
) (UserDatabase, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
|
|
|
@ -20,10 +20,10 @@ import (
|
|||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
)
|
||||
|
||||
|
@ -145,3 +145,47 @@ const (
|
|||
// uint32.
|
||||
AllNotifications NotificationFilter = (1 << 31) - 1
|
||||
)
|
||||
|
||||
type OneTimeKeys interface {
|
||||
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
|
||||
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
|
||||
InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
|
||||
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
|
||||
// Returns an empty map if the key does not exist.
|
||||
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
|
||||
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
}
|
||||
|
||||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
|
||||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
|
||||
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
}
|
||||
|
||||
type StaleDeviceLists interface {
|
||||
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
||||
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
|
||||
}
|
||||
|
||||
type CrossSigningKeys interface {
|
||||
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
|
||||
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
|
||||
}
|
||||
|
||||
type CrossSigningSigs interface {
|
||||
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
|
||||
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
|
||||
}
|
||||
|
|
94
userapi/storage/tables/stale_device_lists_test.go
Normal file
94
userapi/storage/tables/stale_device_lists_test.go
Normal file
|
@ -0,0 +1,94 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open database: %s", err)
|
||||
}
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create new table: %s", err)
|
||||
}
|
||||
return tab, close
|
||||
}
|
||||
|
||||
func TestStaleDeviceLists(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
bob := test.NewUser(t)
|
||||
charlie := "@charlie:localhost"
|
||||
ctx := context.Background()
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
tab, closeDB := mustCreateTable(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
|
||||
t.Fatalf("failed to insert stale device: %s", err)
|
||||
}
|
||||
|
||||
// Query one server
|
||||
wantStaleUsers := []string{alice.ID, bob.ID}
|
||||
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Query all servers
|
||||
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||
}
|
||||
|
||||
// Delete stale devices
|
||||
deleteUsers := []string{alice.ID, bob.ID}
|
||||
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
|
||||
t.Fatalf("failed to delete stale device lists: %s", err)
|
||||
}
|
||||
|
||||
// Verify we don't get anything back after deleting
|
||||
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query stale device lists: %s", err)
|
||||
}
|
||||
|
||||
if gotCount := len(gotStaleUsers); gotCount > 0 {
|
||||
t.Fatalf("expected no stale users, got %d", gotCount)
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue