mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Remodel how device list change IDs are created
Previously we made them using the offset Kafka supplied. We don't run Kafka anymore, so now we make the SQL table assign the change ID via an AUTOINCREMENTing ID. Redesign the `keyserver_key_changes` table to have `UNIQUE(user_id)` so we don't accumulate key changes forevermore, we now have at most 1 row per user which contains the highest change ID. This needs a SQL migration.
This commit is contained in:
parent
31f1810814
commit
5dc360481a
10 changed files with 109 additions and 91 deletions
|
@ -70,6 +70,7 @@ type DeviceMessage struct {
|
|||
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
||||
// A monotonically increasing number which represents device changes for this user.
|
||||
StreamID int
|
||||
DeviceChangeID int64
|
||||
}
|
||||
|
||||
// DeviceKeys represents a set of device keys for a single device
|
||||
|
|
|
@ -59,8 +59,7 @@ func (a *KeyInternalAPI) InputDeviceListUpdate(
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||
partition := 0
|
||||
userIDs, latest, err := a.DB.KeyChanges(ctx, int32(partition), req.Offset, req.ToOffset)
|
||||
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
|
|
|
@ -40,7 +40,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
|||
func NewInternalAPI(
|
||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
||||
) api.KeyInternalAPI {
|
||||
_, consumer, producer := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||
js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||
|
||||
db, err := storage.NewDatabase(&cfg.Database)
|
||||
if err != nil {
|
||||
|
@ -48,7 +48,7 @@ func NewInternalAPI(
|
|||
}
|
||||
keyChangeProducer := &producers.KeyChange{
|
||||
Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
|
||||
Producer: producer,
|
||||
JetStream: js,
|
||||
DB: db,
|
||||
}
|
||||
ap := &internal.KeyInternalAPI{
|
||||
|
|
|
@ -18,17 +18,18 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
eduapi "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KeyChange produces key change events for the sync API and federation sender to consume
|
||||
type KeyChange struct {
|
||||
Topic string
|
||||
Producer sarama.SyncProducer
|
||||
JetStream nats.JetStreamContext
|
||||
DB storage.Database
|
||||
}
|
||||
|
||||
|
@ -36,25 +37,28 @@ type KeyChange struct {
|
|||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
userToDeviceCount := make(map[string]int)
|
||||
for _, key := range keys {
|
||||
var m sarama.ProducerMessage
|
||||
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.DeviceChangeID = id
|
||||
value, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Topic = string(p.Topic)
|
||||
m.Key = sarama.StringEncoder(key.UserID)
|
||||
m.Value = sarama.ByteEncoder(value)
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
partition, offset, err := p.Producer.SendMessage(&m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userToDeviceCount[key.UserID]++
|
||||
}
|
||||
for userID, count := range userToDeviceCount {
|
||||
|
@ -67,7 +71,6 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
|||
}
|
||||
|
||||
func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error {
|
||||
var m sarama.ProducerMessage
|
||||
output := &api.DeviceMessage{
|
||||
Type: api.TypeCrossSigningUpdate,
|
||||
OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{
|
||||
|
@ -75,20 +78,25 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er
|
|||
},
|
||||
}
|
||||
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
output.DeviceChangeID = id
|
||||
|
||||
value, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Topic = string(p.Topic)
|
||||
m.Key = sarama.StringEncoder(key.UserID)
|
||||
m.Value = sarama.ByteEncoder(value)
|
||||
|
||||
partition, offset, err := p.Producer.SendMessage(&m)
|
||||
if err != nil {
|
||||
return err
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -66,14 +66,14 @@ type Database interface {
|
|||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||
|
||||
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
|
||||
// their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||
// `userID` is the the user who has changed their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
|
||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||
// A to offset of sarama.OffsetNewest means no upper limit.
|
||||
// Returns the offset of the latest key change.
|
||||
KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
|
|
|
@ -26,27 +26,27 @@ import (
|
|||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
|
||||
" DO UPDATE SET user_id = $3"
|
||||
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
|
||||
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
|
||||
" RETURNING change_id"
|
||||
|
||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||
// take the max offset value as the latest offset.
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 AND log_offset <= $3 GROUP BY user_id"
|
||||
"SELECT user_id, MAX(change_id) FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2 GROUP BY user_id"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -71,19 +71,19 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
return err
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
|
@ -135,14 +135,16 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
|||
return result, err
|
||||
}
|
||||
|
||||
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
|
||||
func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset)
|
||||
func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
|
|
|
@ -27,27 +27,26 @@ import (
|
|||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
offset BIGINT NOT NULL,
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (partition, offset)
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (partition, offset)" +
|
||||
" DO UPDATE SET user_id = $3"
|
||||
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" ON CONFLICT" +
|
||||
" DO UPDATE SET user_id = $1" +
|
||||
" RETURNING change_id"
|
||||
|
||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||
// take the max offset value as the latest offset.
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||
"SELECT user_id, MAX(change_id) FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2 GROUP BY user_id"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -72,19 +71,19 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
|||
return s, nil
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
return err
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
|
@ -44,15 +44,18 @@ func MustNotError(t *testing.T, err error) {
|
|||
func TestKeyChanges(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest)
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 2 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
@ -62,15 +65,18 @@ func TestKeyChanges(t *testing.T) {
|
|||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest)
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 2 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
@ -80,15 +86,18 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
|||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1)
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 1 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 1", latest)
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
|
|
@ -44,10 +44,10 @@ type DeviceKeys interface {
|
|||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
||||
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
}
|
||||
|
||||
type StaleDeviceLists interface {
|
||||
|
|
Loading…
Reference in a new issue