Merge keyserver & userapi (#2972)

As discussed yesterday, a first draft of merging the keyserver and the
userapi.
This commit is contained in:
Till 2023-02-20 14:58:03 +01:00 committed by GitHub
parent bd6f0c14e5
commit 4594233f89
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
107 changed files with 1730 additions and 1863 deletions

View file

@ -90,7 +90,7 @@ type KeyBackup interface {
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
// determined by the loginTokenLifetime given to the UserDatabase constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
@ -130,7 +130,7 @@ type Notification interface {
DeleteOldNotifications(ctx context.Context) error
}
type Database interface {
type UserDatabase interface {
Account
AccountData
Device
@ -144,6 +144,78 @@ type Database interface {
ThreePID
}
type KeyChangeDatabase interface {
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
// `userID` is the the user who has changed their keys in some way.
StoreKeyChange(ctx context.Context, userID string) (int64, error)
}
type KeyDatabase interface {
KeyChangeDatabase
// ExistingOneTimeKeys returns a map of keyIDWithAlgorithm to key JSON for the given parameters. If no keys exist with this combination
// of user/device/key/algorithm 4-uple then it is omitted from the map. Returns an error when failing to communicate with the database.
ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
// StoreOneTimeKeys persists the given one-time keys.
StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// OneTimeKeysCount returns a count of all OTKs for this device.
OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
// StoreLocalDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device).
// The `StreamID` for each message is set on successful insertion. In the event the key already exists, the existing StreamID is set.
// Returns an error if there was a problem storing the keys.
StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error
// StoreRemoteDeviceKeys persists the given keys. Keys with the same user ID and device ID will be replaced. An empty KeyJSON removes the key
// for this (user, device). Does not modify the stream ID for keys. User IDs in `clearUserIDs` will have all their device keys deleted prior
// to insertion - use this when you have a complete snapshot of a user's keys in order to track device deletions correctly.
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
// A to offset of types.OffsetNewest means no upper limit.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
// MarkDeviceListStale sets the stale bit for this user to isStale.
MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error
CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error)
CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error)
CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error)
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error
}
type Statistics interface {
UserStatistics(ctx context.Context) (*types.UserStatistics, *types.DatabaseEngine, error)
DailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error)

View file

@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
roomID, dataType string, content json.RawMessage,
) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
// when passing the data to trigger "NOT NULL" constraint
var data *json.RawMessage
if len(content) > 0 {
data = &content
}
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
return
}

View file

@ -0,0 +1,102 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningKeysSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
user_id TEXT NOT NULL,
key_type SMALLINT NOT NULL,
key_data TEXT NOT NULL,
PRIMARY KEY (user_id, key_type)
);
`
const selectCrossSigningKeysForUserSQL = "" +
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
" WHERE user_id = $1"
const upsertCrossSigningKeysForUserSQL = "" +
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
" VALUES($1, $2, $3)" +
" ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
type crossSigningKeysStatements struct {
db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt
upsertCrossSigningKeysForUserStmt *sql.Stmt
}
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
s := &crossSigningKeysStatements{
db: db,
}
_, err := db.Exec(crossSigningKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
}.Prepare(db)
}
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string,
) (r types.CrossSigningKeyMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
r = types.CrossSigningKeyMap{}
for rows.Next() {
var keyTypeInt int16
var keyData gomatrixserverlib.Base64Bytes
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
return nil, err
}
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
if !ok {
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
}
r[keyType] = keyData
}
return
}
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
) error {
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)
}
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,131 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningSigsSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
);
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
" VALUES($1, $2, $3, $4, $5)" +
" ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
const deleteCrossSigningSigsForTargetSQL = "" +
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
type crossSigningSigsStatements struct {
db *sql.DB
selectCrossSigningSigsForTargetStmt *sql.Stmt
upsertCrossSigningSigsForTargetStmt *sql.Stmt
deleteCrossSigningSigsForTargetStmt *sql.Stmt
}
func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
s := &crossSigningSigsStatements{
db: db,
}
_, err := db.Exec(crossSigningSigsSchema)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
}.Prepare(db)
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{}
for rows.Next() {
var userID string
var keyID gomatrixserverlib.KeyID
var signature gomatrixserverlib.Base64Bytes
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
return nil, err
}
if _, ok := r[userID]; !ok {
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
r[userID][keyID] = signature
}
return
}
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,69 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
// even are entries in this table to know if we can query for log_offset. Without the count then
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
var count int
_ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
if count > 0 {
var maxOffset int64
_ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
}
}
_, err := tx.ExecContext(ctx, `
-- make the new table
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
log_offset BIGINT NOT NULL,
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
);
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,47 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,213 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
-- Stores device keys for users
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
-- This means we do not store an unbounded append-only log of device keys, which is not actually
-- required in the spec because in the event of a missed update the server fetches the entire
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device.
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
const deleteDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
countStreamIDsForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
}
_, err := db.Exec(deviceKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
{&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
}.Prepare(db)
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
}
return nil
}
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int64
}
return
}
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
// nullable if there are no results
var count sql.NullInt32
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
if err != nil {
return 0, err
}
if count.Valid {
return int(count.Int32), nil
}
return 0, nil
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err
}
}
return nil
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
var result []api.DeviceMessage
var displayName sql.NullString
for rows.Next() {
dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: userID,
},
}
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
}
}
return result, rows.Err()
}

View file

@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,

View file

@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
{&s.selectKeysStmt, selectKeysSQL},
{&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)

View file

@ -0,0 +1,127 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
);
`
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
// have changed, hence we can just keep bumping the change ID for this user.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (user_id)" +
" VALUES ($1)" +
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
" RETURNING change_id"
const selectKeyChangesSQL = "" +
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return s, err
}
if err = executeMigration(context.Background(), db); err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
}.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
migrationName := "keyserver: refactor key changes"
var cName string
err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
if err != nil {
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
}
return nil
}
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: migrationName,
Up: deltas.UpRefactorKeyChanges,
})
return m.Up(ctx)
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, fromOffset, toOffset int64,
) (userIDs []string, latestOffset int64, err error) {
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -0,0 +1,194 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 4-uple of user/device/key/algorithm.
CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
`
const upsertKeysSQL = "" +
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
" DO UPDATE SET key_json = $6"
const selectOneTimeKeysSQL = "" +
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimeKeySQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
const selectKeyByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
const deleteOneTimeKeysSQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
type oneTimeKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteOneTimeKeyStmt *sql.Stmt
deleteOneTimeKeysStmt *sql.Stmt
}
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{
db: db,
}
_, err := db.Exec(oneTimeKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertKeysSQL},
{&s.selectKeysStmt, selectOneTimeKeysSQL},
{&s.selectKeysCountStmt, selectKeysCountSQL},
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
}.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
result := make(map[string]json.RawMessage)
var (
algorithmWithID string
keyJSONStr string
)
for rows.Next() {
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
return nil, err
}
result[algorithmWithID] = json.RawMessage(keyJSONStr)
}
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{
DeviceID: keys.DeviceID,
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}

View file

@ -0,0 +1,131 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
type staleDeviceListsStatements struct {
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
deleteStaleDeviceListsStmt *sql.Stmt
}
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
if err != nil {
return nil, err
}
otk, err := NewPostgresOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewPostgresDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewPostgresStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewPostgresCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewPostgresCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
Writer: writer,
}, nil
}

View file

@ -59,6 +59,17 @@ type Database struct {
OpenIDTokenLifetimeMS int64
}
type KeyDatabase struct {
OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
StaleDeviceListsTable tables.StaleDeviceLists
CrossSigningKeysTable tables.CrossSigningKeys
CrossSigningSigsTable tables.CrossSigningSigs
DB *sql.DB
Writer sqlutil.Writer
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
@ -875,3 +886,227 @@ func (d *Database) DailyRoomsMessages(
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
//
func (d *KeyDatabase) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
return d.OneTimeKeysTable.SelectOneTimeKeys(ctx, userID, deviceID, keyIDsWithAlgorithms)
}
func (d *KeyDatabase) StoreOneTimeKeys(ctx context.Context, keys api.OneTimeKeys) (counts *api.OneTimeKeysCount, err error) {
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
counts, err = d.OneTimeKeysTable.InsertOneTimeKeys(ctx, txn, keys)
return err
})
return
}
func (d *KeyDatabase) OneTimeKeysCount(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
return d.OneTimeKeysTable.CountOneTimeKeys(ctx, userID, deviceID)
}
func (d *KeyDatabase) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
func (d *KeyDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
if err != nil {
return false, err
}
return count == len(prevIDs), nil
}
func (d *KeyDatabase) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, userID := range clearUserIDs {
err := d.DeviceKeysTable.DeleteAllDeviceKeys(ctx, txn, userID)
if err != nil {
return err
}
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *KeyDatabase) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int64)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID := range userIDToStreamID {
streamID, err := d.DeviceKeysTable.SelectMaxStreamIDForUser(ctx, txn, userID)
if err != nil {
return err
}
userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {
k := keys[i]
userIDToStreamID[k.UserID]++ // start stream from 1
k.StreamID = userIDToStreamID[k.UserID]
keys[i] = k
}
return d.DeviceKeysTable.InsertDeviceKeys(ctx, txn, keys)
})
}
func (d *KeyDatabase) DeviceKeysForUser(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
return d.DeviceKeysTable.SelectBatchDeviceKeys(ctx, userID, deviceIDs, includeEmpty)
}
func (d *KeyDatabase) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) {
var result []api.OneTimeKeys
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for userID, deviceToAlgo := range userToDeviceToAlgorithm {
for deviceID, algo := range deviceToAlgo {
keyJSON, err := d.OneTimeKeysTable.SelectAndDeleteOneTimeKey(ctx, txn, userID, deviceID, algo)
if err != nil {
return err
}
if keyJSON != nil {
result = append(result, api.OneTimeKeys{
UserID: userID,
DeviceID: deviceID,
KeyJSON: keyJSON,
})
}
}
}
return nil
})
return result, err
}
func (d *KeyDatabase) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
return err
})
return
}
func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
}
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
// MarkDeviceListStale sets the stale bit for this user to isStale.
func (d *KeyDatabase) MarkDeviceListStale(ctx context.Context, userID string, isStale bool) error {
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
return d.StaleDeviceListsTable.InsertStaleDeviceList(ctx, userID, isStale)
})
}
// DeleteDeviceKeys removes the device keys for a given user/device, and any accompanying
// cross-signing signatures relating to that device.
func (d *KeyDatabase) DeleteDeviceKeys(ctx context.Context, userID string, deviceIDs []gomatrixserverlib.KeyID) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for _, deviceID := range deviceIDs {
if err := d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget(ctx, txn, userID, deviceID); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.CrossSigningSigsTable.DeleteCrossSigningSigsForTarget: %w", err)
}
if err := d.DeviceKeysTable.DeleteDeviceKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.DeviceKeysTable.DeleteDeviceKeys: %w", err)
}
if err := d.OneTimeKeysTable.DeleteOneTimeKeys(ctx, txn, userID, string(deviceID)); err != nil && err != sql.ErrNoRows {
return fmt.Errorf("d.OneTimeKeysTable.DeleteOneTimeKeys: %w", err)
}
}
return nil
})
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string) (map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey, error) {
keyMap, err := d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
if err != nil {
return nil, fmt.Errorf("d.CrossSigningKeysTable.SelectCrossSigningKeysForUser: %w", err)
}
results := map[gomatrixserverlib.CrossSigningKeyPurpose]gomatrixserverlib.CrossSigningKey{}
for purpose, key := range keyMap {
keyID := gomatrixserverlib.KeyID("ed25519:" + key.Encode())
result := gomatrixserverlib.CrossSigningKey{
UserID: userID,
Usage: []gomatrixserverlib.CrossSigningKeyPurpose{purpose},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
keyID: key,
},
}
sigMap, err := d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, userID, userID, keyID)
if err != nil {
continue
}
for sigUserID, forSigUserID := range sigMap {
if userID != sigUserID {
continue
}
if result.Signatures == nil {
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
if _, ok := result.Signatures[sigUserID]; !ok {
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
for sigKeyID, sigBytes := range forSigUserID {
result.Signatures[sigUserID][sigKeyID] = sigBytes
}
}
results[purpose] = result
}
return results, nil
}
// CrossSigningKeysForUser returns the latest known cross-signing keys for a user, if any.
func (d *KeyDatabase) CrossSigningKeysDataForUser(ctx context.Context, userID string) (types.CrossSigningKeyMap, error) {
return d.CrossSigningKeysTable.SelectCrossSigningKeysForUser(ctx, nil, userID)
}
// CrossSigningSigsForTarget returns the signatures for a given user's key ID, if any.
func (d *KeyDatabase) CrossSigningSigsForTarget(ctx context.Context, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (types.CrossSigningSigMap, error) {
return d.CrossSigningSigsTable.SelectCrossSigningSigsForTarget(ctx, nil, originUserID, targetUserID, targetKeyID)
}
// StoreCrossSigningKeysForUser stores the latest known cross-signing keys for a user.
func (d *KeyDatabase) StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
for keyType, keyData := range keyMap {
if err := d.CrossSigningKeysTable.UpsertCrossSigningKeysForUser(ctx, txn, userID, keyType, keyData); err != nil {
return fmt.Errorf("d.CrossSigningKeysTable.InsertCrossSigningKeysForUser: %w", err)
}
}
return nil
})
}
// StoreCrossSigningSigsForTarget stores a signature for a target user ID and key/dvice.
func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("d.CrossSigningSigsTable.InsertCrossSigningSigsForTarget: %w", err)
}
return nil
})
}
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
func (d *KeyDatabase) DeleteStaleDeviceLists(
ctx context.Context,
userIDs []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
})
}

View file

@ -0,0 +1,101 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningKeysSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
user_id TEXT NOT NULL,
key_type INTEGER NOT NULL,
key_data TEXT NOT NULL,
PRIMARY KEY (user_id, key_type)
);
`
const selectCrossSigningKeysForUserSQL = "" +
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
" WHERE user_id = $1"
const upsertCrossSigningKeysForUserSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
" VALUES($1, $2, $3)"
type crossSigningKeysStatements struct {
db *sql.DB
selectCrossSigningKeysForUserStmt *sql.Stmt
upsertCrossSigningKeysForUserStmt *sql.Stmt
}
func NewSqliteCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
s := &crossSigningKeysStatements{
db: db,
}
_, err := db.Exec(crossSigningKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
}.Prepare(db)
}
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string,
) (r types.CrossSigningKeyMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
r = types.CrossSigningKeyMap{}
for rows.Next() {
var keyTypeInt int16
var keyData gomatrixserverlib.Base64Bytes
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
return nil, err
}
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
if !ok {
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
}
r[keyType] = keyData
}
return
}
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
) error {
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
if !ok {
return fmt.Errorf("unknown key purpose %q", keyType)
}
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,129 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
var crossSigningSigsSchema = `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
);
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`
const selectCrossSigningSigsForTargetSQL = "" +
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $3 AND target_key_id = $4"
const upsertCrossSigningSigsForTargetSQL = "" +
"INSERT OR REPLACE INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
" VALUES($1, $2, $3, $4, $5)"
const deleteCrossSigningSigsForTargetSQL = "" +
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
type crossSigningSigsStatements struct {
db *sql.DB
selectCrossSigningSigsForTargetStmt *sql.Stmt
upsertCrossSigningSigsForTargetStmt *sql.Stmt
deleteCrossSigningSigsForTargetStmt *sql.Stmt
}
func NewSqliteCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
s := &crossSigningSigsStatements{
db: db,
}
_, err := db.Exec(crossSigningSigsSchema)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: "keyserver: cross signing signature indexes",
Up: deltas.UpFixCrossSigningSignatureIndexes,
})
if err = m.Up(context.Background()); err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
}.Prepare(db)
}
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) (r types.CrossSigningSigMap, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetUserID, targetKeyID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForOriginTargetStmt: rows.close() failed")
r = types.CrossSigningSigMap{}
for rows.Next() {
var userID string
var keyID gomatrixserverlib.KeyID
var signature gomatrixserverlib.Base64Bytes
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
return nil, err
}
if _, ok := r[userID]; !ok {
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
}
r[userID][keyID] = signature
}
return
}
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
) error {
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
ctx context.Context, txn *sql.Tx,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
) error {
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
}
return nil
}

View file

@ -0,0 +1,66 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
// start counting from the last max offset, else 0.
var maxOffset int64
var userID string
_ = tx.QueryRowContext(ctx, `SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
_, err := tx.ExecContext(ctx, `
-- make the new table
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (user_id)
);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
// to start counting from maxOffset, insert a row with that value
if userID != "" {
_, err = tx.ExecContext(ctx, `INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
return err
}
return nil
}
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
DROP TABLE IF EXISTS keyserver_key_changes;
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
offset BIGINT NOT NULL,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (partition, offset)
);
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,71 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"context"
"database/sql"
"fmt"
)
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
);
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
DROP TABLE keyserver_cross_signing_sigs;
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
_, err := tx.ExecContext(ctx, `
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs_tmp (
origin_user_id TEXT NOT NULL,
origin_key_id TEXT NOT NULL,
target_user_id TEXT NOT NULL,
target_key_id TEXT NOT NULL,
signature TEXT NOT NULL,
PRIMARY KEY (origin_user_id, target_user_id, target_key_id)
);
INSERT INTO keyserver_cross_signing_sigs_tmp (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)
SELECT origin_user_id, origin_key_id, target_user_id, target_key_id, signature FROM keyserver_cross_signing_sigs;
DROP TABLE keyserver_cross_signing_sigs;
ALTER TABLE keyserver_cross_signing_sigs_tmp RENAME TO keyserver_cross_signing_sigs;
DELETE INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -0,0 +1,213 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var deviceKeysSchema = `
-- Stores device keys for users
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
stream_id BIGINT NOT NULL,
display_name TEXT,
-- Clobber based on tuple of user/device.
UNIQUE (user_id, device_id)
);
`
const upsertDeviceKeysSQL = "" +
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id)" +
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
const selectDeviceKeysSQL = "" +
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const selectBatchDeviceKeysSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
const selectBatchDeviceKeysWithEmptiesSQL = "" +
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
const selectMaxStreamForUserSQL = "" +
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
const countStreamIDsForUserSQL = "" +
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id IN ($2)"
const deleteDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
const deleteAllDeviceKeysSQL = "" +
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
type deviceKeysStatements struct {
db *sql.DB
upsertDeviceKeysStmt *sql.Stmt
selectDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysStmt *sql.Stmt
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
selectMaxStreamForUserStmt *sql.Stmt
deleteDeviceKeysStmt *sql.Stmt
deleteAllDeviceKeysStmt *sql.Stmt
}
func NewSqliteDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
s := &deviceKeysStatements{
db: db,
}
_, err := db.Exec(deviceKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
// {&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL}, // prepared at runtime
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
}.Prepare(db)
}
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
return err
}
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
deviceIDMap := make(map[string]bool)
for _, d := range deviceIDs {
deviceIDMap[d] = true
}
var stmt *sql.Stmt
if includeEmpty {
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
} else {
stmt = s.selectBatchDeviceKeysStmt
}
rows, err := stmt.QueryContext(ctx, userID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
var result []api.DeviceMessage
var displayName sql.NullString
for rows.Next() {
dk := api.DeviceMessage{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
UserID: userID,
},
}
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
return nil, err
}
if displayName.Valid {
dk.DisplayName = displayName.String
}
// include the key if we want all keys (no device) or it was asked
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
result = append(result, dk)
}
}
return result, rows.Err()
}
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
return err
}
// this will be '' when there is no device
keys[i].Type = api.TypeDeviceKeyUpdate
keys[i].KeyJSON = []byte(keyJSONStr)
keys[i].StreamID = streamID
if displayName.Valid {
keys[i].DisplayName = displayName.String
}
}
return nil
}
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int64
}
return
}
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
iStreamIDs := make([]interface{}, len(streamIDs)+1)
iStreamIDs[0] = userID
for i := range streamIDs {
iStreamIDs[i+1] = streamIDs[i]
}
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results
var count sql.NullInt64
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
if err != nil {
return 0, err
}
if count.Valid {
return int(count.Int64), nil
}
return 0, nil
}
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
for _, key := range keys {
now := time.Now().Unix()
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
)
if err != nil {
return err
}
}
return nil
}

View file

@ -151,7 +151,7 @@ func (s *devicesStatements) InsertDevice(
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -159,7 +159,11 @@ func (s *devicesStatements) InsertDevice(
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
@ -172,7 +176,7 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
return nil, err
}
return &api.Device{
dev := &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, serverName),
AccessToken: accessToken,
@ -180,7 +184,11 @@ func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *
LastSeenTS: createdTimeMS,
LastSeenIP: ipAddr,
UserAgent: userAgent,
}, nil
}
if displayName != nil {
dev.DisplayName = *displayName
}
return dev, nil
}
func (s *devicesStatements) DeleteDevice(
@ -202,6 +210,7 @@ func (s *devicesStatements) DeleteDevices(
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, prep, "DeleteDevices.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+2)
params[0] = localpart

View file

@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
const countKeysSQL = "" +
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
const selectKeysSQL = "" +
const selectBackupKeysSQL = "" +
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
"WHERE user_id = $1 AND version = $2"
@ -83,7 +83,7 @@ func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
{&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL},
{&s.selectKeysStmt, selectKeysSQL},
{&s.selectKeysStmt, selectBackupKeysSQL},
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
}.Prepare(db)

View file

@ -0,0 +1,125 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"errors"
"fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (user_id)
);
`
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
// have changed, hence we can just keep bumping the change ID for this user.
const upsertKeyChangeSQL = "" +
"INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
" VALUES ($1)" +
" RETURNING change_id"
const selectKeyChangesSQL = "" +
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return s, err
}
if err = executeMigration(context.Background(), db); err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
}.Prepare(db)
}
func executeMigration(ctx context.Context, db *sql.DB) error {
// TODO: Remove when we are sure we are not having goose artefacts in the db
// This forces an error, which indicates the migration is already applied, since the
// column partition was removed from the table
migrationName := "keyserver: refactor key changes"
var cName string
err := db.QueryRowContext(ctx, `SELECT p.name FROM sqlite_master AS m JOIN pragma_table_info(m.name) AS p WHERE m.name = 'keyserver_key_changes' AND p.name = 'partition'`).Scan(&cName)
if err != nil {
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
}
return nil
}
return err
}
m := sqlutil.NewMigrator(db)
m.AddMigrations(sqlutil.Migration{
Version: migrationName,
Up: deltas.UpRefactorKeyChanges,
})
return m.Up(ctx)
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
return
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, fromOffset, toOffset int64,
) (userIDs []string, latestOffset int64, err error) {
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -0,0 +1,208 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
var oneTimeKeysSchema = `
-- Stores one-time public keys for users
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
key_id TEXT NOT NULL,
algorithm TEXT NOT NULL,
ts_added_secs BIGINT NOT NULL,
key_json TEXT NOT NULL,
-- Clobber based on 4-uple of user/device/key/algorithm.
UNIQUE (user_id, device_id, key_id, algorithm)
);
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
`
const upsertKeysSQL = "" +
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (user_id, device_id, key_id, algorithm)" +
" DO UPDATE SET key_json = $6"
const selectOneTimeKeysSQL = "" +
"SELECT key_id, algorithm, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2"
const selectKeysCountSQL = "" +
"SELECT algorithm, COUNT(key_id) FROM " +
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
" x GROUP BY algorithm"
const deleteOneTimeKeySQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
const selectKeyByAlgorithmSQL = "" +
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
const deleteOneTimeKeysSQL = "" +
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
type oneTimeKeysStatements struct {
db *sql.DB
upsertKeysStmt *sql.Stmt
selectKeysStmt *sql.Stmt
selectKeysCountStmt *sql.Stmt
selectKeyByAlgorithmStmt *sql.Stmt
deleteOneTimeKeyStmt *sql.Stmt
deleteOneTimeKeysStmt *sql.Stmt
}
func NewSqliteOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
s := &oneTimeKeysStatements{
db: db,
}
_, err := db.Exec(oneTimeKeysSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertKeysStmt, upsertKeysSQL},
{&s.selectKeysStmt, selectOneTimeKeysSQL},
{&s.selectKeysCountStmt, selectKeysCountSQL},
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
}.Prepare(db)
}
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
wantSet := make(map[string]bool, len(keyIDsWithAlgorithms))
for _, ka := range keyIDsWithAlgorithms {
wantSet[ka] = true
}
result := make(map[string]json.RawMessage)
for rows.Next() {
var keyID string
var algorithm string
var keyJSONStr string
if err := rows.Scan(&keyID, &algorithm, &keyJSONStr); err != nil {
return nil, err
}
keyIDWithAlgo := algorithm + ":" + keyID
if wantSet[keyIDWithAlgo] {
result[keyIDWithAlgo] = json.RawMessage(keyJSONStr)
}
}
return result, rows.Err()
}
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
counts := &api.OneTimeKeysCount{
DeviceID: deviceID,
UserID: userID,
KeyCount: make(map[string]int),
}
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, nil
}
func (s *oneTimeKeysStatements) InsertOneTimeKeys(
ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys,
) (*api.OneTimeKeysCount, error) {
now := time.Now().Unix()
counts := &api.OneTimeKeysCount{
DeviceID: keys.DeviceID,
UserID: keys.UserID,
KeyCount: make(map[string]int),
}
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
algo, keyID := keys.Split(keyIDWithAlgo)
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
)
if err != nil {
return nil, err
}
}
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
for rows.Next() {
var algorithm string
var count int
if err = rows.Scan(&algorithm, &count); err != nil {
return nil, err
}
counts.KeyCount[algorithm] = count
}
return counts, rows.Err()
}
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
) (map[string]json.RawMessage, error) {
var keyID string
var keyJSON string
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
if err != nil {
return nil, err
}
if keyJSON == "" {
return nil, nil
}
return map[string]json.RawMessage{
algorithm + ":" + keyID: json.RawMessage(keyJSON),
}, err
}
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
return err
}

View file

@ -0,0 +1,145 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
var staleDeviceListsSchema = `
-- Stores whether a user's device lists are stale or not.
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
user_id TEXT PRIMARY KEY NOT NULL,
domain TEXT NOT NULL,
is_stale BOOLEAN NOT NULL,
ts_added_secs BIGINT NOT NULL
);
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
`
const upsertStaleDeviceListSQL = "" +
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
" VALUES ($1, $2, $3, $4)" +
" ON CONFLICT (user_id)" +
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
const selectStaleDeviceListsWithDomainsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
const selectStaleDeviceListsSQL = "" +
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
const deleteStaleDevicesSQL = "" +
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
type staleDeviceListsStatements struct {
db *sql.DB
upsertStaleDeviceListStmt *sql.Stmt
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
selectStaleDeviceListsStmt *sql.Stmt
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
}
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
s := &staleDeviceListsStatements{
db: db,
}
_, err := db.Exec(staleDeviceListsSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
}.Prepare(db)
}
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
return err
}
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
return err
}
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
// we only query for 1 domain or all domains so optimise for those use cases
if len(domains) == 0 {
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
if err != nil {
return nil, err
}
return rowsToUserIDs(ctx, rows)
}
var result []string
for _, domain := range domains {
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
if err != nil {
return nil, err
}
userIDs, err := rowsToUserIDs(ctx, rows)
if err != nil {
return nil, err
}
result = append(result, userIDs...)
}
return result, nil
}
// DeleteStaleDeviceLists removes users from stale device lists
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
ctx context.Context, txn *sql.Tx, userIDs []string,
) error {
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
stmt, err := s.db.Prepare(qry)
if err != nil {
return err
}
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
stmt = sqlutil.TxStmt(txn, stmt)
params := make([]any, len(userIDs))
for i := range userIDs {
params[i] = userIDs[i]
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
for rows.Next() {
var userID string
if err := rows.Scan(&userID); err != nil {
return nil, err
}
result = append(result, userID)
}
return result, rows.Err()
}

View file

@ -256,6 +256,7 @@ func (s *statsStatements) allUsers(ctx context.Context, txn *sql.Tx) (result int
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "allUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3, 4,
@ -269,6 +270,7 @@ func (s *statsStatements) nonBridgedUsers(ctx context.Context, txn *sql.Tx) (res
if err != nil {
return 0, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "nonBridgedUsers.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
err = stmt.QueryRowContext(ctx,
1, 2, 3,
@ -286,6 +288,7 @@ func (s *statsStatements) registeredUserByType(ctx context.Context, txn *sql.Tx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, queryStmt, "registeredUserByType.StmtClose() failed")
stmt := sqlutil.TxStmt(txn, queryStmt)
registeredAfter := time.Now().AddDate(0, 0, -30)

View file

@ -30,8 +30,8 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// NewDatabase creates a new accounts and profiles database
func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
// NewUserDatabase creates a new accounts and profiles database
func NewUserDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (*shared.Database, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
@ -134,3 +134,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
}, nil
}
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter())
if err != nil {
return nil, err
}
otk, err := NewSqliteOneTimeKeysTable(db)
if err != nil {
return nil, err
}
dk, err := NewSqliteDeviceKeysTable(db)
if err != nil {
return nil, err
}
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
sdl, err := NewSqliteStaleDeviceListsTable(db)
if err != nil {
return nil, err
}
csk, err := NewSqliteCrossSigningKeysTable(db)
if err != nil {
return nil, err
}
css, err := NewSqliteCrossSigningSigsTable(db)
if err != nil {
return nil, err
}
return &shared.KeyDatabase{
OneTimeKeysTable: otk,
DeviceKeysTable: dk,
KeyChangesTable: kc,
StaleDeviceListsTable: sdl,
CrossSigningKeysTable: csk,
CrossSigningSigsTable: css,
Writer: writer,
}, nil
}

View file

@ -29,15 +29,36 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
)
// NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// NewUserDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewUserAPIDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
func NewUserDatabase(
base *base.BaseDendrite,
dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName,
bcryptCost int,
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
default:
return nil, fmt.Errorf("unexpected database type")
}
}
// NewKeyDatabase opens a new Postgres or Sqlite database (base on dataSourceName) scheme)
// and sets postgres connection parameters.
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (KeyDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewKeyDatabase(base, dbProperties)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewKeyDatabase(base, dbProperties)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -4,9 +4,12 @@ import (
"context"
"encoding/json"
"fmt"
"reflect"
"sync"
"testing"
"time"
"github.com/matrix-org/dendrite/userapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
@ -29,14 +32,14 @@ var (
ctx = context.Background()
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
func mustCreateUserDatabase(t *testing.T, dbType test.DBType) (storage.UserDatabase, func()) {
base, baseclose := testrig.CreateBaseDendrite(t, dbType)
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(base, &config.DatabaseOptions{
db, err := storage.NewUserDatabase(base, &config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err)
t.Fatalf("NewUserDatabase returned %s", err)
}
return db, func() {
close()
@ -47,7 +50,7 @@ func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, fun
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
localpart, domain, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -78,7 +81,7 @@ func Test_AccountData(t *testing.T) {
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
alice := test.NewUser(t)
aliceLocalpart, aliceDomain, err := gomatrixserverlib.SplitID('@', alice.ID)
@ -158,7 +161,7 @@ func Test_Devices(t *testing.T) {
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, domain, &deviceID, accessToken, nil, "", "")
@ -238,7 +241,7 @@ func Test_KeyBackup(t *testing.T) {
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
@ -315,7 +318,7 @@ func Test_KeyBackup(t *testing.T) {
func Test_LoginToken(t *testing.T) {
alice := test.NewUser(t)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create a new token
@ -347,7 +350,7 @@ func Test_OpenID(t *testing.T) {
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
@ -368,7 +371,7 @@ func Test_Profile(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// create account, which also creates a profile
@ -417,7 +420,7 @@ func Test_Pusher(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
@ -468,7 +471,7 @@ func Test_ThreePID(t *testing.T) {
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
@ -507,7 +510,7 @@ func Test_Notification(t *testing.T) {
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
db, close := mustCreateUserDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
@ -571,3 +574,184 @@ func Test_Notification(t *testing.T) {
assert.Equal(t, int64(0), total)
})
}
func mustCreateKeyDatabase(t *testing.T, dbType test.DBType) (storage.KeyDatabase, func()) {
base, close := testrig.CreateBaseDendrite(t, dbType)
db, err := storage.NewKeyDatabase(base, &base.Cfg.KeyServer.Database)
if err != nil {
t.Fatalf("failed to create new database: %v", err)
}
return db, close
}
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDC {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesNoDupes(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
if deviceChangeIDA == deviceChangeIDB {
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
}
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, 0, types.OffsetNewest)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeID {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
func TestKeyChangesUpperLimit(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
MustNotError(t, err)
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
MustNotError(t, err)
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
MustNotError(t, err)
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != deviceChangeIDB {
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
}
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
})
}
var dbLock sync.Mutex
var deviceArray = []string{"AAA", "another_device"}
// The purpose of this test is to make sure that the storage layer is generating sequential stream IDs per user,
// and that they are returned correctly when querying for device keys.
func TestDeviceKeysStreamIDGeneration(t *testing.T) {
var err error
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, clean := mustCreateKeyDatabase(t, dbType)
defer clean()
alice := "@alice:TestDeviceKeysStreamIDGeneration"
bob := "@bob:TestDeviceKeysStreamIDGeneration"
msgs := []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: bob,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 1 as this is a different user
},
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "another_device",
UserID: alice,
KeyJSON: []byte(`{"key":"v1"}`),
},
// StreamID: 2 as this is a 2nd device key
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 but got %d", msgs[0].StreamID)
}
if msgs[1].StreamID != 1 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=1 (different user) but got %d", msgs[1].StreamID)
}
if msgs[2].StreamID != 2 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=2 (another device) but got %d", msgs[2].StreamID)
}
// updating a device sets the next stream ID for that user
msgs = []api.DeviceMessage{
{
Type: api.TypeDeviceKeyUpdate,
DeviceKeys: &api.DeviceKeys{
DeviceID: "AAA",
UserID: alice,
KeyJSON: []byte(`{"key":"v2"}`),
},
// StreamID: 3
},
}
MustNotError(t, db.StoreLocalDeviceKeys(ctx, msgs))
if msgs[0].StreamID != 3 {
t.Fatalf("Expected StoreLocalDeviceKeys to set StreamID=3 (new key same device) but got %d", msgs[0].StreamID)
}
dbLock.Lock()
defer dbLock.Unlock()
// Querying for device keys returns the latest stream IDs
msgs, err = db.DeviceKeysForUser(ctx, alice, deviceArray, false)
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}
if len(msgs) != len(wantStreamIDs) {
t.Fatalf("DeviceKeysForUser: wrong number of devices, got %d want %d", len(msgs), len(wantStreamIDs))
}
for _, m := range msgs {
if m.StreamID != wantStreamIDs[m.DeviceID] {
t.Errorf("DeviceKeysForUser: wrong returned stream ID for key, got %d want %d", m.StreamID, wantStreamIDs[m.DeviceID])
}
}
})
}

View file

@ -32,10 +32,10 @@ func NewUserAPIDatabase(
openIDTokenLifetimeMS int64,
loginTokenLifetime time.Duration,
serverNoticesLocalpart string,
) (Database, error) {
) (UserDatabase, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
return sqlite3.NewUserDatabase(base, dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:

View file

@ -20,10 +20,10 @@ import (
"encoding/json"
"time"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/types"
)
@ -145,3 +145,47 @@ const (
// uint32.
AllNotifications NotificationFilter = (1 << 31) - 1
)
type OneTimeKeys interface {
SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error)
CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error)
InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error)
// SelectAndDeleteOneTimeKey selects a single one time key matching the user/device/algorithm specified and returns the algo:key_id => JSON.
// Returns an empty map if the key does not exist.
SelectAndDeleteOneTimeKey(ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string) (map[string]json.RawMessage, error)
DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
}
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error
}
type KeyChanges interface {
InsertKeyChange(ctx context.Context, userID string) (int64, error)
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of types.OffsetNewest means no upper offset.
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
}
type StaleDeviceLists interface {
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
}
type CrossSigningKeys interface {
SelectCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string) (r types.CrossSigningKeyMap, err error)
UpsertCrossSigningKeysForUser(ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes) error
}
type CrossSigningSigs interface {
SelectCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID) (r types.CrossSigningSigMap, err error)
UpsertCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
DeleteCrossSigningSigsForTarget(ctx context.Context, txn *sql.Tx, targetUserID string, targetKeyID gomatrixserverlib.KeyID) error
}

View file

@ -0,0 +1,94 @@
package tables_test
import (
"context"
"testing"
"github.com/matrix-org/dendrite/userapi/storage/postgres"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, nil)
if err != nil {
t.Fatalf("failed to open database: %s", err)
}
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
}
if err != nil {
t.Fatalf("failed to create new table: %s", err)
}
return tab, close
}
func TestStaleDeviceLists(t *testing.T) {
alice := test.NewUser(t)
bob := test.NewUser(t)
charlie := "@charlie:localhost"
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
tab, closeDB := mustCreateTable(t, dbType)
defer closeDB()
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
t.Fatalf("failed to insert stale device: %s", err)
}
// Query one server
wantStaleUsers := []string{alice.ID, bob.ID}
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Query all servers
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
}
// Delete stale devices
deleteUsers := []string{alice.ID, bob.ID}
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
t.Fatalf("failed to delete stale device lists: %s", err)
}
// Verify we don't get anything back after deleting
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
if err != nil {
t.Fatalf("failed to query stale device lists: %s", err)
}
if gotCount := len(gotStaleUsers); gotCount > 0 {
t.Fatalf("expected no stale users, got %d", gotCount)
}
})
}