Persist partition|offset|user_id in the keyserver (#1226)

* Persist partition|offset|user_id in the keyserver

Required for a query API which will be used by the syncapi which
will be called when a `/sync` request comes in which will return
a list of user IDs of people who have changed their device keys
between two tokens.

* Add tests and fix maxOffset bug

* s/offset/log_offset/g because 'offset' is a reserved word in postgres
This commit is contained in:
Kegsay 2020-07-28 17:38:30 +01:00 committed by GitHub
parent acc8e80a51
commit adf7b59294
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 292 additions and 0 deletions

View file

@ -49,6 +49,7 @@ func NewInternalAPI(
keyChangeProducer := &producers.KeyChange{ keyChangeProducer := &producers.KeyChange{
Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent), Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
Producer: producer, Producer: producer,
DB: db,
} }
return &internal.KeyInternalAPI{ return &internal.KeyInternalAPI{
DB: db, DB: db,

View file

@ -15,10 +15,12 @@
package producers package producers
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/Shopify/sarama" "github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/keyserver/storage"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -26,6 +28,7 @@ import (
type KeyChange struct { type KeyChange struct {
Topic string Topic string
Producer sarama.SyncProducer Producer sarama.SyncProducer
DB storage.Database
} }
// ProduceKeyChanges creates new change events for each key // ProduceKeyChanges creates new change events for each key
@ -46,6 +49,10 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceKeys) error {
if err != nil { if err != nil {
return err return err
} }
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
if err != nil {
return err
}
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"user_id": key.UserID, "user_id": key.UserID,
"device_id": key.DeviceID, "device_id": key.DeviceID,

View file

@ -43,4 +43,12 @@ type Database interface {
// ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key // ClaimKeys based on the 3-uple of user_id, device_id and algorithm name. Returns the keys claimed. Returns no error if a key
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
// their keys in some way.
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
// KeyChanges returns a list of user IDs who have modified their keys from the offset given.
// Returns the offset of the latest key change.
KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
} }

View file

@ -0,0 +1,97 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
log_offset BIGINT NOT NULL,
user_id TEXT NOT NULL,
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -34,9 +34,14 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*s
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewPostgresKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil }, nil
} }

View file

@ -28,6 +28,7 @@ type Database struct {
DB *sql.DB DB *sql.DB
OneTimeKeysTable tables.OneTimeKeys OneTimeKeysTable tables.OneTimeKeys
DeviceKeysTable tables.DeviceKeys DeviceKeysTable tables.DeviceKeys
KeyChangesTable tables.KeyChanges
} }
func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) { func (d *Database) ExistingOneTimeKeys(ctx context.Context, userID, deviceID string, keyIDsWithAlgorithms []string) (map[string]json.RawMessage, error) {
@ -72,3 +73,11 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
}) })
return result, err return result, err
} }
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
}
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error) {
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset)
}

View file

@ -0,0 +1,98 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/storage/tables"
)
var keyChangesSchema = `
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
partition BIGINT NOT NULL,
offset BIGINT NOT NULL,
-- The key owner
user_id TEXT NOT NULL,
UNIQUE (partition, offset)
);
`
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
const upsertKeyChangeSQL = "" +
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
" VALUES ($1, $2, $3)" +
" ON CONFLICT (partition, offset)" +
" DO UPDATE SET user_id = $3"
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
// take the max offset value as the latest offset.
const selectKeyChangesSQL = "" +
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 GROUP BY user_id"
type keyChangesStatements struct {
db *sql.DB
upsertKeyChangeStmt *sql.Stmt
selectKeyChangesStmt *sql.Stmt
}
func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
s := &keyChangesStatements{
db: db,
}
_, err := db.Exec(keyChangesSchema)
if err != nil {
return nil, err
}
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
return nil, err
}
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
return err
}
func (s *keyChangesStatements) SelectKeyChanges(
ctx context.Context, partition int32, fromOffset int64,
) (userIDs []string, latestOffset int64, err error) {
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset)
if err != nil {
return nil, 0, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectKeyChangesStmt: rows.close() failed")
for rows.Next() {
var userID string
var offset int64
if err := rows.Scan(&userID, &offset); err != nil {
return nil, 0, err
}
if offset > latestOffset {
latestOffset = offset
}
userIDs = append(userIDs, userID)
}
return
}

View file

@ -37,9 +37,14 @@ func NewDatabase(dataSourceName string) (*shared.Database, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
kc, err := NewSqliteKeyChangesTable(db)
if err != nil {
return nil, err
}
return &shared.Database{ return &shared.Database{
DB: db, DB: db,
OneTimeKeysTable: otk, OneTimeKeysTable: otk,
DeviceKeysTable: dk, DeviceKeysTable: dk,
KeyChangesTable: kc,
}, nil }, nil
} }

View file

@ -0,0 +1,57 @@
package storage
import (
"context"
"reflect"
"testing"
)
var ctx = context.Background()
func MustNotError(t *testing.T, err error) {
t.Helper()
if err == nil {
return
}
t.Fatalf("operation failed: %s", err)
}
func TestKeyChanges(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 1)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}
func TestKeyChangesNoDupes(t *testing.T) {
db, err := NewDatabase("file::memory:", nil)
if err != nil {
t.Fatalf("Failed to NewDatabase: %s", err)
}
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
userIDs, latest, err := db.KeyChanges(ctx, 0, 0)
if err != nil {
t.Fatalf("Failed to KeyChanges: %s", err)
}
if latest != 2 {
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
}
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
}
}

View file

@ -35,3 +35,8 @@ type DeviceKeys interface {
InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error InsertDeviceKeys(ctx context.Context, keys []api.DeviceKeys) error
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error) SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string) ([]api.DeviceKeys, error)
} }
type KeyChanges interface {
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
SelectKeyChanges(ctx context.Context, partition int32, fromOffset int64) (userIDs []string, latestOffset int64, err error)
}