mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 22:02:46 +00:00
Merge keyserver & userapi (#2972)
As discussed yesterday, a first draft of merging the keyserver and the userapi.
This commit is contained in:
parent
bd6f0c14e5
commit
4594233f89
107 changed files with 1730 additions and 1863 deletions
|
@ -78,7 +78,13 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
roomID, dataType string, content json.RawMessage,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
// Empty/nil json.RawMessage is not interpreted as "nil", so use *json.RawMessage
|
||||
// when passing the data to trigger "NOT NULL" constraint
|
||||
var data *json.RawMessage
|
||||
if len(content) > 0 {
|
||||
data = &content
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, localpart, serverName, roomID, dataType, data)
|
||||
return
|
||||
}
|
||||
|
||||
|
|
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
102
userapi/storage/postgres/cross_signing_keys_table.go
Normal file
|
@ -0,0 +1,102 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningKeysSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
key_type SMALLINT NOT NULL,
|
||||
key_data TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key_type)
|
||||
);
|
||||
`
|
||||
|
||||
const selectCrossSigningKeysForUserSQL = "" +
|
||||
"SELECT key_type, key_data FROM keyserver_cross_signing_keys" +
|
||||
" WHERE user_id = $1"
|
||||
|
||||
const upsertCrossSigningKeysForUserSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_keys (user_id, key_type, key_data)" +
|
||||
" VALUES($1, $2, $3)" +
|
||||
" ON CONFLICT (user_id, key_type) DO UPDATE SET key_data = $3"
|
||||
|
||||
type crossSigningKeysStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningKeysForUserStmt *sql.Stmt
|
||||
upsertCrossSigningKeysForUserStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningKeysTable(db *sql.DB) (tables.CrossSigningKeys, error) {
|
||||
s := &crossSigningKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningKeysForUserStmt, selectCrossSigningKeysForUserSQL},
|
||||
{&s.upsertCrossSigningKeysForUserStmt, upsertCrossSigningKeysForUserSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) SelectCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string,
|
||||
) (r types.CrossSigningKeyMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningKeysForUserStmt).QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningKeysForUserStmt: rows.close() failed")
|
||||
r = types.CrossSigningKeyMap{}
|
||||
for rows.Next() {
|
||||
var keyTypeInt int16
|
||||
var keyData gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&keyTypeInt, &keyData); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyType, ok := types.KeyTypeIntToPurpose[keyTypeInt]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unknown key purpose int %d", keyTypeInt)
|
||||
}
|
||||
r[keyType] = keyData
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningKeysStatements) UpsertCrossSigningKeysForUser(
|
||||
ctx context.Context, txn *sql.Tx, userID string, keyType gomatrixserverlib.CrossSigningKeyPurpose, keyData gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
keyTypeInt, ok := types.KeyTypePurposeToInt[keyType]
|
||||
if !ok {
|
||||
return fmt.Errorf("unknown key purpose %q", keyType)
|
||||
}
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningKeysForUserStmt).ExecContext(ctx, userID, keyTypeInt, keyData); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningKeysForUserStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
131
userapi/storage/postgres/cross_signing_sigs_table.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/userapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var crossSigningSigsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS keyserver_cross_signing_sigs (
|
||||
origin_user_id TEXT NOT NULL,
|
||||
origin_key_id TEXT NOT NULL,
|
||||
target_user_id TEXT NOT NULL,
|
||||
target_key_id TEXT NOT NULL,
|
||||
signature TEXT NOT NULL,
|
||||
PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`
|
||||
|
||||
const selectCrossSigningSigsForTargetSQL = "" +
|
||||
"SELECT origin_user_id, origin_key_id, signature FROM keyserver_cross_signing_sigs" +
|
||||
" WHERE (origin_user_id = $1 OR origin_user_id = $2) AND target_user_id = $2 AND target_key_id = $3"
|
||||
|
||||
const upsertCrossSigningSigsForTargetSQL = "" +
|
||||
"INSERT INTO keyserver_cross_signing_sigs (origin_user_id, origin_key_id, target_user_id, target_key_id, signature)" +
|
||||
" VALUES($1, $2, $3, $4, $5)" +
|
||||
" ON CONFLICT (origin_user_id, origin_key_id, target_user_id, target_key_id) DO UPDATE SET signature = $5"
|
||||
|
||||
const deleteCrossSigningSigsForTargetSQL = "" +
|
||||
"DELETE FROM keyserver_cross_signing_sigs WHERE target_user_id=$1 AND target_key_id=$2"
|
||||
|
||||
type crossSigningSigsStatements struct {
|
||||
db *sql.DB
|
||||
selectCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
upsertCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
deleteCrossSigningSigsForTargetStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresCrossSigningSigsTable(db *sql.DB) (tables.CrossSigningSigs, error) {
|
||||
s := &crossSigningSigsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(crossSigningSigsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "keyserver: cross signing signature indexes",
|
||||
Up: deltas.UpFixCrossSigningSignatureIndexes,
|
||||
})
|
||||
if err = m.Up(context.Background()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.selectCrossSigningSigsForTargetStmt, selectCrossSigningSigsForTargetSQL},
|
||||
{&s.upsertCrossSigningSigsForTargetStmt, upsertCrossSigningSigsForTargetSQL},
|
||||
{&s.deleteCrossSigningSigsForTargetStmt, deleteCrossSigningSigsForTargetSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) SelectCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx, originUserID, targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) (r types.CrossSigningSigMap, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectCrossSigningSigsForTargetStmt).QueryContext(ctx, originUserID, targetUserID, targetKeyID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectCrossSigningSigsForTargetStmt: rows.close() failed")
|
||||
r = types.CrossSigningSigMap{}
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var keyID gomatrixserverlib.KeyID
|
||||
var signature gomatrixserverlib.Base64Bytes
|
||||
if err := rows.Scan(&userID, &keyID, &signature); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := r[userID]; !ok {
|
||||
r[userID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
|
||||
}
|
||||
r[userID][keyID] = signature
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) UpsertCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
originUserID string, originKeyID gomatrixserverlib.KeyID,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
signature gomatrixserverlib.Base64Bytes,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.upsertCrossSigningSigsForTargetStmt).ExecContext(ctx, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {
|
||||
return fmt.Errorf("s.upsertCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *crossSigningSigsStatements) DeleteCrossSigningSigsForTarget(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
|
||||
) error {
|
||||
if _, err := sqlutil.TxStmt(txn, s.deleteCrossSigningSigsForTargetStmt).ExecContext(ctx, targetUserID, targetKeyID); err != nil {
|
||||
return fmt.Errorf("s.deleteCrossSigningSigsForTargetStmt: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,69 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
|
||||
// even are entries in this table to know if we can query for log_offset. Without the count then
|
||||
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
|
||||
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
|
||||
var count int
|
||||
_ = tx.QueryRowContext(ctx, `SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
|
||||
if count > 0 {
|
||||
var maxOffset int64
|
||||
_ = tx.QueryRowContext(ctx, `SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
|
||||
if _, err := tx.ExecContext(ctx, fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
|
||||
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func UpFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, origin_key_id, target_user_id, target_key_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_cross_signing_sigs_idx ON keyserver_cross_signing_sigs (origin_user_id, target_user_id, target_key_id);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownFixCrossSigningSignatureIndexes(ctx context.Context, tx *sql.Tx) error {
|
||||
_, err := tx.ExecContext(ctx, `
|
||||
ALTER TABLE keyserver_cross_signing_sigs DROP CONSTRAINT keyserver_cross_signing_sigs_pkey;
|
||||
ALTER TABLE keyserver_cross_signing_sigs ADD PRIMARY KEY (origin_user_id, target_user_id, target_key_id);
|
||||
|
||||
DROP INDEX IF EXISTS keyserver_cross_signing_sigs_idx;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
213
userapi/storage/postgres/device_keys_table.go
Normal file
213
userapi/storage/postgres/device_keys_table.go
Normal file
|
@ -0,0 +1,213 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var deviceKeysSchema = `
|
||||
-- Stores device keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_device_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- the stream ID of this key, scoped per-user. This gets updated when the device key changes.
|
||||
-- This means we do not store an unbounded append-only log of device keys, which is not actually
|
||||
-- required in the spec because in the event of a missed update the server fetches the entire
|
||||
-- current set of keys rather than trying to 'fast-forward' or catchup missing stream IDs.
|
||||
stream_id BIGINT NOT NULL,
|
||||
display_name TEXT,
|
||||
-- Clobber based on tuple of user/device.
|
||||
CONSTRAINT keyserver_device_keys_unique UNIQUE (user_id, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertDeviceKeysSQL = "" +
|
||||
"INSERT INTO keyserver_device_keys (user_id, device_id, ts_added_secs, key_json, stream_id, display_name)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_device_keys_unique" +
|
||||
" DO UPDATE SET key_json = $4, stream_id = $5, display_name = $6"
|
||||
|
||||
const selectDeviceKeysSQL = "" +
|
||||
"SELECT key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const selectBatchDeviceKeysSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1 AND key_json <> ''"
|
||||
|
||||
const selectBatchDeviceKeysWithEmptiesSQL = "" +
|
||||
"SELECT device_id, key_json, stream_id, display_name FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const selectMaxStreamForUserSQL = "" +
|
||||
"SELECT MAX(stream_id) FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
const countStreamIDsForUserSQL = "" +
|
||||
"SELECT COUNT(*) FROM keyserver_device_keys WHERE user_id=$1 AND stream_id = ANY($2)"
|
||||
|
||||
const deleteDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1 AND device_id=$2"
|
||||
|
||||
const deleteAllDeviceKeysSQL = "" +
|
||||
"DELETE FROM keyserver_device_keys WHERE user_id=$1"
|
||||
|
||||
type deviceKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertDeviceKeysStmt *sql.Stmt
|
||||
selectDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysStmt *sql.Stmt
|
||||
selectBatchDeviceKeysWithEmptiesStmt *sql.Stmt
|
||||
selectMaxStreamForUserStmt *sql.Stmt
|
||||
countStreamIDsForUserStmt *sql.Stmt
|
||||
deleteDeviceKeysStmt *sql.Stmt
|
||||
deleteAllDeviceKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
||||
s := &deviceKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(deviceKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertDeviceKeysStmt, upsertDeviceKeysSQL},
|
||||
{&s.selectDeviceKeysStmt, selectDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysStmt, selectBatchDeviceKeysSQL},
|
||||
{&s.selectBatchDeviceKeysWithEmptiesStmt, selectBatchDeviceKeysWithEmptiesSQL},
|
||||
{&s.selectMaxStreamForUserStmt, selectMaxStreamForUserSQL},
|
||||
{&s.countStreamIDsForUserStmt, countStreamIDsForUserSQL},
|
||||
{&s.deleteDeviceKeysStmt, deleteDeviceKeysSQL},
|
||||
{&s.deleteAllDeviceKeysStmt, deleteAllDeviceKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// this will be '' when there is no device
|
||||
keys[i].Type = api.TypeDeviceKeyUpdate
|
||||
keys[i].KeyJSON = []byte(keyJSONStr)
|
||||
keys[i].StreamID = streamID
|
||||
if displayName.Valid {
|
||||
keys[i].DisplayName = displayName.String
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error) {
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
err := s.countStreamIDsForUserStmt.QueryRowContext(ctx, userID, pq.Int64Array(streamIDs)).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error {
|
||||
for _, key := range keys {
|
||||
now := time.Now().Unix()
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertDeviceKeysStmt).ExecContext(
|
||||
ctx, key.UserID, key.DeviceID, now, string(key.KeyJSON), key.StreamID, key.DisplayName,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteDeviceKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) DeleteAllDeviceKeys(ctx context.Context, txn *sql.Tx, userID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteAllDeviceKeysStmt).ExecContext(ctx, userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error) {
|
||||
var stmt *sql.Stmt
|
||||
if includeEmpty {
|
||||
stmt = s.selectBatchDeviceKeysWithEmptiesStmt
|
||||
} else {
|
||||
stmt = s.selectBatchDeviceKeysStmt
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectBatchDeviceKeysStmt: rows.close() failed")
|
||||
deviceIDMap := make(map[string]bool)
|
||||
for _, d := range deviceIDs {
|
||||
deviceIDMap[d] = true
|
||||
}
|
||||
var result []api.DeviceMessage
|
||||
var displayName sql.NullString
|
||||
for rows.Next() {
|
||||
dk := api.DeviceMessage{
|
||||
Type: api.TypeDeviceKeyUpdate,
|
||||
DeviceKeys: &api.DeviceKeys{
|
||||
UserID: userID,
|
||||
},
|
||||
}
|
||||
if err := rows.Scan(&dk.DeviceID, &dk.KeyJSON, &dk.StreamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
dk.DisplayName = displayName.String
|
||||
}
|
||||
// include the key if we want all keys (no device) or it was asked
|
||||
if deviceIDMap[dk.DeviceID] || len(deviceIDs) == 0 {
|
||||
result = append(result, dk)
|
||||
}
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -160,7 +160,7 @@ func (s *devicesStatements) InsertDevice(
|
|||
if err := stmt.QueryRowContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, createdTimeMS, ipAddr, userAgent).Scan(&sessionID); err != nil {
|
||||
return nil, fmt.Errorf("insertDeviceStmt: %w", err)
|
||||
}
|
||||
return &api.Device{
|
||||
dev := &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
|
@ -168,7 +168,11 @@ func (s *devicesStatements) InsertDevice(
|
|||
LastSeenTS: createdTimeMS,
|
||||
LastSeenIP: ipAddr,
|
||||
UserAgent: userAgent,
|
||||
}, nil
|
||||
}
|
||||
if displayName != nil {
|
||||
dev.DisplayName = *displayName
|
||||
}
|
||||
return dev, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) InsertDeviceWithSessionID(ctx context.Context, txn *sql.Tx, id,
|
||||
|
|
|
@ -52,7 +52,7 @@ const updateBackupKeySQL = "" +
|
|||
const countKeysSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_key_backups WHERE user_id = $1 AND version = $2"
|
||||
|
||||
const selectKeysSQL = "" +
|
||||
const selectBackupKeysSQL = "" +
|
||||
"SELECT room_id, session_id, first_message_index, forwarded_count, is_verified, session_data FROM userapi_key_backups " +
|
||||
"WHERE user_id = $1 AND version = $2"
|
||||
|
||||
|
@ -83,7 +83,7 @@ func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
|
|||
{&s.insertBackupKeyStmt, insertBackupKeySQL},
|
||||
{&s.updateBackupKeyStmt, updateBackupKeySQL},
|
||||
{&s.countKeysStmt, countKeysSQL},
|
||||
{&s.selectKeysStmt, selectKeysSQL},
|
||||
{&s.selectKeysStmt, selectBackupKeysSQL},
|
||||
{&s.selectKeysByRoomIDStmt, selectKeysByRoomIDSQL},
|
||||
{&s.selectKeysByRoomIDAndSessionIDStmt, selectKeysByRoomIDAndSessionIDSQL},
|
||||
}.Prepare(db)
|
||||
|
|
127
userapi/storage/postgres/key_changes_table.go
Normal file
127
userapi/storage/postgres/key_changes_table.go
Normal file
|
@ -0,0 +1,127 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
|
||||
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
|
||||
" RETURNING change_id"
|
||||
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeyChangeStmt *sql.Stmt
|
||||
selectKeyChangesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
||||
s := &keyChangesStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return s, err
|
||||
}
|
||||
|
||||
if err = executeMigration(context.Background(), db); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeyChangeStmt, upsertKeyChangeSQL},
|
||||
{&s.selectKeyChangesStmt, selectKeyChangesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func executeMigration(ctx context.Context, db *sql.DB) error {
|
||||
// TODO: Remove when we are sure we are not having goose artefacts in the db
|
||||
// This forces an error, which indicates the migration is already applied, since the
|
||||
// column partition was removed from the table
|
||||
migrationName := "keyserver: refactor key changes"
|
||||
|
||||
var cName string
|
||||
err := db.QueryRowContext(ctx, "select column_name from information_schema.columns where table_name = 'keyserver_key_changes' AND column_name = 'partition'").Scan(&cName)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) { // migration was already executed, as the column was removed
|
||||
if err = sqlutil.InsertMigration(ctx, db, migrationName); err != nil {
|
||||
return fmt.Errorf("unable to manually insert migration '%s': %w", migrationName, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
m := sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: migrationName,
|
||||
Up: deltas.UpRefactorKeyChanges,
|
||||
})
|
||||
|
||||
return m.Up(ctx)
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
var offset int64
|
||||
if err := rows.Scan(&userID, &offset); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if offset > latestOffset {
|
||||
latestOffset = offset
|
||||
}
|
||||
userIDs = append(userIDs, userID)
|
||||
}
|
||||
return
|
||||
}
|
194
userapi/storage/postgres/one_time_keys_table.go
Normal file
194
userapi/storage/postgres/one_time_keys_table.go
Normal file
|
@ -0,0 +1,194 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
)
|
||||
|
||||
var oneTimeKeysSchema = `
|
||||
-- Stores one-time public keys for users
|
||||
CREATE TABLE IF NOT EXISTS keyserver_one_time_keys (
|
||||
user_id TEXT NOT NULL,
|
||||
device_id TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
algorithm TEXT NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL,
|
||||
key_json TEXT NOT NULL,
|
||||
-- Clobber based on 4-uple of user/device/key/algorithm.
|
||||
CONSTRAINT keyserver_one_time_keys_unique UNIQUE (user_id, device_id, key_id, algorithm)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_one_time_keys_idx ON keyserver_one_time_keys (user_id, device_id);
|
||||
`
|
||||
|
||||
const upsertKeysSQL = "" +
|
||||
"INSERT INTO keyserver_one_time_keys (user_id, device_id, key_id, algorithm, ts_added_secs, key_json)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_one_time_keys_unique" +
|
||||
" DO UPDATE SET key_json = $6"
|
||||
|
||||
const selectOneTimeKeysSQL = "" +
|
||||
"SELECT concat(algorithm, ':', key_id) as algorithmwithid, key_json FROM keyserver_one_time_keys WHERE user_id=$1 AND device_id=$2 AND concat(algorithm, ':', key_id) = ANY($3);"
|
||||
|
||||
const selectKeysCountSQL = "" +
|
||||
"SELECT algorithm, COUNT(key_id) FROM " +
|
||||
" (SELECT algorithm, key_id FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 LIMIT 100)" +
|
||||
" x GROUP BY algorithm"
|
||||
|
||||
const deleteOneTimeKeySQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 AND key_id = $4"
|
||||
|
||||
const selectKeyByAlgorithmSQL = "" +
|
||||
"SELECT key_id, key_json FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2 AND algorithm = $3 LIMIT 1"
|
||||
|
||||
const deleteOneTimeKeysSQL = "" +
|
||||
"DELETE FROM keyserver_one_time_keys WHERE user_id = $1 AND device_id = $2"
|
||||
|
||||
type oneTimeKeysStatements struct {
|
||||
db *sql.DB
|
||||
upsertKeysStmt *sql.Stmt
|
||||
selectKeysStmt *sql.Stmt
|
||||
selectKeysCountStmt *sql.Stmt
|
||||
selectKeyByAlgorithmStmt *sql.Stmt
|
||||
deleteOneTimeKeyStmt *sql.Stmt
|
||||
deleteOneTimeKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOneTimeKeysTable(db *sql.DB) (tables.OneTimeKeys, error) {
|
||||
s := &oneTimeKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(oneTimeKeysSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertKeysStmt, upsertKeysSQL},
|
||||
{&s.selectKeysStmt, selectOneTimeKeysSQL},
|
||||
{&s.selectKeysCountStmt, selectKeysCountSQL},
|
||||
{&s.selectKeyByAlgorithmStmt, selectKeyByAlgorithmSQL},
|
||||
{&s.deleteOneTimeKeyStmt, deleteOneTimeKeySQL},
|
||||
{&s.deleteOneTimeKeysStmt, deleteOneTimeKeysSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
|
||||
rows, err := s.selectKeysStmt.QueryContext(ctx, userID, deviceID, pq.Array(keyIDsWithAlgorithms))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysStmt: rows.close() failed")
|
||||
|
||||
result := make(map[string]json.RawMessage)
|
||||
var (
|
||||
algorithmWithID string
|
||||
keyJSONStr string
|
||||
)
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&algorithmWithID, &keyJSONStr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result[algorithmWithID] = json.RawMessage(keyJSONStr)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) CountOneTimeKeys(ctx context.Context, userID, deviceID string) (*api.OneTimeKeysCount, error) {
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: deviceID,
|
||||
UserID: userID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
rows, err := s.selectKeysCountStmt.QueryContext(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) InsertOneTimeKeys(ctx context.Context, txn *sql.Tx, keys api.OneTimeKeys) (*api.OneTimeKeysCount, error) {
|
||||
now := time.Now().Unix()
|
||||
counts := &api.OneTimeKeysCount{
|
||||
DeviceID: keys.DeviceID,
|
||||
UserID: keys.UserID,
|
||||
KeyCount: make(map[string]int),
|
||||
}
|
||||
for keyIDWithAlgo, keyJSON := range keys.KeyJSON {
|
||||
algo, keyID := keys.Split(keyIDWithAlgo)
|
||||
_, err := sqlutil.TxStmt(txn, s.upsertKeysStmt).ExecContext(
|
||||
ctx, keys.UserID, keys.DeviceID, keyID, algo, now, string(keyJSON),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectKeysCountStmt).QueryContext(ctx, keys.UserID, keys.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectKeysCountStmt: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var algorithm string
|
||||
var count int
|
||||
if err = rows.Scan(&algorithm, &count); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
counts.KeyCount[algorithm] = count
|
||||
}
|
||||
|
||||
return counts, rows.Err()
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) SelectAndDeleteOneTimeKey(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, algorithm string,
|
||||
) (map[string]json.RawMessage, error) {
|
||||
var keyID string
|
||||
var keyJSON string
|
||||
err := sqlutil.TxStmtContext(ctx, txn, s.selectKeyByAlgorithmStmt).QueryRowContext(ctx, userID, deviceID, algorithm).Scan(&keyID, &keyJSON)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
_, err = sqlutil.TxStmtContext(ctx, txn, s.deleteOneTimeKeyStmt).ExecContext(ctx, userID, deviceID, algorithm, keyID)
|
||||
return map[string]json.RawMessage{
|
||||
algorithm + ":" + keyID: json.RawMessage(keyJSON),
|
||||
}, err
|
||||
}
|
||||
|
||||
func (s *oneTimeKeysStatements) DeleteOneTimeKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deleteOneTimeKeysStmt).ExecContext(ctx, userID, deviceID)
|
||||
return err
|
||||
}
|
131
userapi/storage/postgres/stale_device_lists.go
Normal file
131
userapi/storage/postgres/stale_device_lists.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
var staleDeviceListsSchema = `
|
||||
-- Stores whether a user's device lists are stale or not.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_stale_device_lists (
|
||||
user_id TEXT PRIMARY KEY NOT NULL,
|
||||
domain TEXT NOT NULL,
|
||||
is_stale BOOLEAN NOT NULL,
|
||||
ts_added_secs BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keyserver_stale_device_lists_idx ON keyserver_stale_device_lists (domain, is_stale);
|
||||
`
|
||||
|
||||
const upsertStaleDeviceListSQL = "" +
|
||||
"INSERT INTO keyserver_stale_device_lists (user_id, domain, is_stale, ts_added_secs)" +
|
||||
" VALUES ($1, $2, $3, $4)" +
|
||||
" ON CONFLICT (user_id)" +
|
||||
" DO UPDATE SET is_stale = $3, ts_added_secs = $4"
|
||||
|
||||
const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 AND domain = $2 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const selectStaleDeviceListsSQL = "" +
|
||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||
|
||||
const deleteStaleDevicesSQL = "" +
|
||||
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
|
||||
|
||||
type staleDeviceListsStatements struct {
|
||||
upsertStaleDeviceListStmt *sql.Stmt
|
||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||
selectStaleDeviceListsStmt *sql.Stmt
|
||||
deleteStaleDeviceListsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||
s := &staleDeviceListsStatements{}
|
||||
_, err := db.Exec(staleDeviceListsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||
_, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.upsertStaleDeviceListStmt.ExecContext(ctx, userID, string(domain), isStale, gomatrixserverlib.AsTimestamp(time.Now()))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||
// we only query for 1 domain or all domains so optimise for those use cases
|
||||
if len(domains) == 0 {
|
||||
rows, err := s.selectStaleDeviceListsStmt.QueryContext(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return rowsToUserIDs(ctx, rows)
|
||||
}
|
||||
var result []string
|
||||
for _, domain := range domains {
|
||||
rows, err := s.selectStaleDeviceListsWithDomainsStmt.QueryContext(ctx, true, string(domain))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userIDs, err := rowsToUserIDs(ctx, rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userIDs...)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DeleteStaleDeviceLists removes users from stale device lists
|
||||
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||
for rows.Next() {
|
||||
var userID string
|
||||
if err := rows.Scan(&userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, userID)
|
||||
}
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -136,3 +136,44 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewKeyDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions) (*shared.KeyDatabase, error) {
|
||||
db, writer, err := base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
otk, err := NewPostgresOneTimeKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dk, err := NewPostgresDeviceKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
kc, err := NewPostgresKeyChangesTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sdl, err := NewPostgresStaleDeviceListsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
csk, err := NewPostgresCrossSigningKeysTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
css, err := NewPostgresCrossSigningSigsTable(db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.KeyDatabase{
|
||||
OneTimeKeysTable: otk,
|
||||
DeviceKeysTable: dk,
|
||||
KeyChangesTable: kc,
|
||||
StaleDeviceListsTable: sdl,
|
||||
CrossSigningKeysTable: csk,
|
||||
CrossSigningSigsTable: css,
|
||||
Writer: writer,
|
||||
}, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue