mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Merge branch 'master' into neilalexander/federationinput
This commit is contained in:
commit
7a93bb32e7
29 changed files with 408 additions and 381 deletions
|
@ -189,7 +189,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir,
|
|||
if err := decoder.Decode(&dl); err != nil {
|
||||
return "", fmt.Errorf("failed to decode build image output line: %w", err)
|
||||
}
|
||||
log.Printf("%s: %s", branchOrTagName, dl.Stream)
|
||||
if len(strings.TrimSpace(dl.Stream)) > 0 {
|
||||
log.Printf("%s: %s", branchOrTagName, dl.Stream)
|
||||
}
|
||||
if dl.Aux != nil {
|
||||
imgID, ok := dl.Aux["ID"]
|
||||
if ok {
|
||||
|
@ -425,8 +427,10 @@ func cleanup(dockerClient *client.Client) {
|
|||
// ignore all errors, we are just cleaning up and don't want to fail just because we fail to cleanup
|
||||
containers, _ := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{
|
||||
Filters: label(dendriteUpgradeTestLabel),
|
||||
All: true,
|
||||
})
|
||||
for _, c := range containers {
|
||||
log.Printf("Removing container: %v %v\n", c.ID, c.Names)
|
||||
s := time.Second
|
||||
_ = dockerClient.ContainerStop(context.Background(), c.ID, &s)
|
||||
_ = dockerClient.ContainerRemove(context.Background(), c.ID, types.ContainerRemoveOptions{
|
||||
|
|
|
@ -368,6 +368,7 @@ logging:
|
|||
- type: std
|
||||
level: info
|
||||
- type: file
|
||||
# The logging level, must be one of debug, info, warn, error, fatal, panic.
|
||||
level: info
|
||||
params:
|
||||
path: ./logs
|
||||
|
|
|
@ -1,3 +1,7 @@
|
|||
#change IP to location of monolith server
|
||||
upstream monolith{
|
||||
server 127.0.0.1:8008;
|
||||
}
|
||||
server {
|
||||
listen 443 ssl; # IPv4
|
||||
listen [::]:443 ssl; # IPv6
|
||||
|
@ -23,6 +27,6 @@ server {
|
|||
}
|
||||
|
||||
location /_matrix {
|
||||
proxy_pass http://monolith:8008;
|
||||
proxy_pass http://monolith;
|
||||
}
|
||||
}
|
||||
|
|
|
@ -69,7 +69,8 @@ type DeviceMessage struct {
|
|||
*DeviceKeys `json:"DeviceKeys,omitempty"`
|
||||
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
||||
// A monotonically increasing number which represents device changes for this user.
|
||||
StreamID int
|
||||
StreamID int
|
||||
DeviceChangeID int64
|
||||
}
|
||||
|
||||
// DeviceKeys represents a set of device keys for a single device
|
||||
|
@ -224,8 +225,6 @@ type QueryKeysResponse struct {
|
|||
}
|
||||
|
||||
type QueryKeyChangesRequest struct {
|
||||
// The partition which had key events sent to
|
||||
Partition int32
|
||||
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
|
||||
Offset int64
|
||||
// The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
|
||||
|
@ -236,8 +235,6 @@ type QueryKeyChangesRequest struct {
|
|||
type QueryKeyChangesResponse struct {
|
||||
// The set of users who have had their keys change.
|
||||
UserIDs []string
|
||||
// The partition being served - useful if the partition is unknown at request time
|
||||
Partition int32
|
||||
// The latest offset represented in this response.
|
||||
Offset int64
|
||||
// Set if there was a problem handling the request.
|
||||
|
|
|
@ -59,17 +59,13 @@ func (a *KeyInternalAPI) InputDeviceListUpdate(
|
|||
}
|
||||
|
||||
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
|
||||
if req.Partition < 0 {
|
||||
req.Partition = a.Producer.DefaultPartition()
|
||||
}
|
||||
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset, req.ToOffset)
|
||||
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset)
|
||||
if err != nil {
|
||||
res.Error = &api.KeyError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
res.Offset = latest
|
||||
res.Partition = req.Partition
|
||||
res.UserIDs = userIDs
|
||||
}
|
||||
|
||||
|
|
|
@ -40,16 +40,16 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) {
|
|||
func NewInternalAPI(
|
||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient,
|
||||
) api.KeyInternalAPI {
|
||||
_, consumer, producer := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||
js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream)
|
||||
|
||||
db, err := storage.NewDatabase(&cfg.Database)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to key server database")
|
||||
}
|
||||
keyChangeProducer := &producers.KeyChange{
|
||||
Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
|
||||
Producer: producer,
|
||||
DB: db,
|
||||
Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)),
|
||||
JetStream: js,
|
||||
DB: db,
|
||||
}
|
||||
ap := &internal.KeyInternalAPI{
|
||||
DB: db,
|
||||
|
|
|
@ -18,52 +18,47 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
eduapi "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||
"github.com/matrix-org/dendrite/setup/jetstream"
|
||||
"github.com/nats-io/nats.go"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KeyChange produces key change events for the sync API and federation sender to consume
|
||||
type KeyChange struct {
|
||||
Topic string
|
||||
Producer sarama.SyncProducer
|
||||
DB storage.Database
|
||||
}
|
||||
|
||||
// DefaultPartition returns the default partition this process is sending key changes to.
|
||||
// NB: A keyserver MUST send key changes to only 1 partition or else query operations will
|
||||
// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but
|
||||
// then all keyservers must be queried to calculate the entire set of key changes between
|
||||
// two sync tokens.
|
||||
func (p *KeyChange) DefaultPartition() int32 {
|
||||
return 0
|
||||
Topic string
|
||||
JetStream nats.JetStreamContext
|
||||
DB storage.Database
|
||||
}
|
||||
|
||||
// ProduceKeyChanges creates new change events for each key
|
||||
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
||||
userToDeviceCount := make(map[string]int)
|
||||
for _, key := range keys {
|
||||
var m sarama.ProducerMessage
|
||||
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.DeviceChangeID = id
|
||||
value, err := json.Marshal(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Topic = string(p.Topic)
|
||||
m.Key = sarama.StringEncoder(key.UserID)
|
||||
m.Value = sarama.ByteEncoder(value)
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
partition, offset, err := p.Producer.SendMessage(&m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
userToDeviceCount[key.UserID]++
|
||||
}
|
||||
for userID, count := range userToDeviceCount {
|
||||
|
@ -76,7 +71,6 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
|
|||
}
|
||||
|
||||
func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error {
|
||||
var m sarama.ProducerMessage
|
||||
output := &api.DeviceMessage{
|
||||
Type: api.TypeCrossSigningUpdate,
|
||||
OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{
|
||||
|
@ -84,20 +78,25 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er
|
|||
},
|
||||
}
|
||||
|
||||
id, err := p.DB.StoreKeyChange(context.Background(), key.UserID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
output.DeviceChangeID = id
|
||||
|
||||
value, err := json.Marshal(output)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Topic = string(p.Topic)
|
||||
m.Key = sarama.StringEncoder(key.UserID)
|
||||
m.Value = sarama.ByteEncoder(value)
|
||||
|
||||
partition, offset, err := p.Producer.SendMessage(&m)
|
||||
if err != nil {
|
||||
return err
|
||||
m := &nats.Msg{
|
||||
Subject: p.Topic,
|
||||
Header: nats.Header{},
|
||||
}
|
||||
err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID)
|
||||
m.Header.Set(jetstream.UserID, key.UserID)
|
||||
m.Data = value
|
||||
|
||||
_, err = p.JetStream.PublishMsg(m)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -66,14 +66,14 @@ type Database interface {
|
|||
// cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice.
|
||||
ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error)
|
||||
|
||||
// StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed
|
||||
// their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||
// StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change.
|
||||
// `userID` is the the user who has changed their keys in some way.
|
||||
StoreKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
|
||||
// KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive).
|
||||
// A to offset of sarama.OffsetNewest means no upper limit.
|
||||
// Returns the offset of the latest key change.
|
||||
KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
// If no domains are given, all user IDs with stale device lists are returned.
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||
}
|
||||
|
||||
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||
}
|
||||
|
||||
func UpRefactorKeyChanges(tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0. We need to do a count(*) first to see if there
|
||||
// even are entries in this table to know if we can query for log_offset. Without the count then
|
||||
// the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't
|
||||
// exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/
|
||||
var count int
|
||||
_ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count)
|
||||
if count > 0 {
|
||||
var maxOffset int64
|
||||
_ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset)
|
||||
if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil {
|
||||
return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err)
|
||||
}
|
||||
}
|
||||
|
||||
_, err := tx.Exec(`
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(`
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP SEQUENCE IF EXISTS keyserver_key_changes_seq;
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -26,27 +26,25 @@ import (
|
|||
|
||||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
log_offset BIGINT NOT NULL,
|
||||
change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'),
|
||||
user_id TEXT NOT NULL,
|
||||
CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset)
|
||||
CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" +
|
||||
" DO UPDATE SET user_id = $3"
|
||||
"INSERT INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" +
|
||||
" DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" +
|
||||
" RETURNING change_id"
|
||||
|
||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||
// take the max offset value as the latest offset.
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 AND log_offset <= $3 GROUP BY user_id"
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -59,31 +57,32 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
|||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
return err
|
||||
func (s *keyChangesStatements) Prepare() (err error) {
|
||||
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ package postgres
|
|||
|
||||
import (
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
@ -51,6 +52,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadRefactorKeyChanges(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = kc.Prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &shared.Database{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
|
|
|
@ -135,14 +135,16 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st
|
|||
return result, err
|
||||
}
|
||||
|
||||
func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
return d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID)
|
||||
func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) {
|
||||
err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error {
|
||||
id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset)
|
||||
func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) {
|
||||
return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset)
|
||||
}
|
||||
|
||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||
}
|
||||
|
||||
func LoadRefactorKeyChanges(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges)
|
||||
}
|
||||
|
||||
func UpRefactorKeyChanges(tx *sql.Tx) error {
|
||||
// start counting from the last max offset, else 0.
|
||||
var maxOffset int64
|
||||
var userID string
|
||||
_ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset)
|
||||
|
||||
_, err := tx.Exec(`
|
||||
-- make the new table
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
// to start counting from maxOffset, insert a row with that value
|
||||
if userID != "" {
|
||||
_, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRefactorKeyChanges(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(`
|
||||
-- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers
|
||||
DROP TABLE IF EXISTS keyserver_key_changes;
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
offset BIGINT NOT NULL,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (partition, offset)
|
||||
);
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute downgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -27,27 +27,22 @@ import (
|
|||
var keyChangesSchema = `
|
||||
-- Stores key change information about users. Used to determine when to send updated device lists to clients.
|
||||
CREATE TABLE IF NOT EXISTS keyserver_key_changes (
|
||||
partition BIGINT NOT NULL,
|
||||
offset BIGINT NOT NULL,
|
||||
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The key owner
|
||||
user_id TEXT NOT NULL,
|
||||
UNIQUE (partition, offset)
|
||||
UNIQUE (user_id)
|
||||
);
|
||||
`
|
||||
|
||||
// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped.
|
||||
// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will
|
||||
// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too.
|
||||
// Replace based on user ID. We don't care how many times the user's keys have changed, only that they
|
||||
// have changed, hence we can just keep bumping the change ID for this user.
|
||||
const upsertKeyChangeSQL = "" +
|
||||
"INSERT INTO keyserver_key_changes (partition, offset, user_id)" +
|
||||
" VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (partition, offset)" +
|
||||
" DO UPDATE SET user_id = $3"
|
||||
"INSERT OR REPLACE INTO keyserver_key_changes (user_id)" +
|
||||
" VALUES ($1)" +
|
||||
" RETURNING change_id"
|
||||
|
||||
// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just
|
||||
// take the max offset value as the latest offset.
|
||||
const selectKeyChangesSQL = "" +
|
||||
"SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id"
|
||||
"SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2"
|
||||
|
||||
type keyChangesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -60,31 +55,32 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) {
|
|||
db: db,
|
||||
}
|
||||
_, err := db.Exec(keyChangesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
return s, err
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error {
|
||||
_, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID)
|
||||
return err
|
||||
func (s *keyChangesStatements) Prepare() (err error) {
|
||||
if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) {
|
||||
err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *keyChangesStatements) SelectKeyChanges(
|
||||
ctx context.Context, partition int32, fromOffset, toOffset int64,
|
||||
ctx context.Context, fromOffset, toOffset int64,
|
||||
) (userIDs []string, latestOffset int64, err error) {
|
||||
if toOffset == sarama.OffsetNewest {
|
||||
toOffset = math.MaxInt64
|
||||
}
|
||||
latestOffset = fromOffset
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
|
||||
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ package sqlite3
|
|||
import (
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
|
@ -49,6 +50,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error)
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadRefactorKeyChanges(m)
|
||||
if err = m.RunDeltas(db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = kc.Prepare(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d := &shared.Database{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
|
|
|
@ -44,15 +44,18 @@ func MustNotError(t *testing.T, err error) {
|
|||
func TestKeyChanges(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest)
|
||||
_, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 2 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||
if latest != deviceChangeIDC {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
@ -62,15 +65,21 @@ func TestKeyChanges(t *testing.T) {
|
|||
func TestKeyChangesNoDupes(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest)
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
if deviceChangeIDA == deviceChangeIDB {
|
||||
t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA)
|
||||
}
|
||||
deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 2 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 2", latest)
|
||||
if latest != deviceChangeID {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
@ -80,15 +89,18 @@ func TestKeyChangesNoDupes(t *testing.T) {
|
|||
func TestKeyChangesUpperLimit(t *testing.T) {
|
||||
db, clean := MustCreateDatabase(t)
|
||||
defer clean()
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost"))
|
||||
MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost"))
|
||||
userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1)
|
||||
deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost")
|
||||
MustNotError(t, err)
|
||||
deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost")
|
||||
MustNotError(t, err)
|
||||
_, err = db.StoreKeyChange(ctx, "@charlie:localhost")
|
||||
MustNotError(t, err)
|
||||
userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to KeyChanges: %s", err)
|
||||
}
|
||||
if latest != 1 {
|
||||
t.Fatalf("KeyChanges: got latest=%d want 1", latest)
|
||||
if latest != deviceChangeIDB {
|
||||
t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB)
|
||||
}
|
||||
if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) {
|
||||
t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs)
|
||||
|
|
|
@ -44,10 +44,12 @@ type DeviceKeys interface {
|
|||
}
|
||||
|
||||
type KeyChanges interface {
|
||||
InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error
|
||||
InsertKeyChange(ctx context.Context, userID string) (int64, error)
|
||||
// SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two offsets.
|
||||
// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
|
||||
SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
|
||||
|
||||
Prepare() error
|
||||
}
|
||||
|
||||
type StaleDeviceLists interface {
|
||||
|
|
|
@ -842,9 +842,13 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if roomInfo == nil || roomInfo.IsStub {
|
||||
if roomInfo == nil {
|
||||
return nil, fmt.Errorf("room %s doesn't exist", roomID)
|
||||
}
|
||||
// e.g invited rooms
|
||||
if roomInfo.IsStub {
|
||||
return nil, nil
|
||||
}
|
||||
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
|
||||
if err == sql.ErrNoRows {
|
||||
// No rooms have an event of this type, otherwise we'd have an event type NID
|
||||
|
|
|
@ -17,7 +17,6 @@ package consumers
|
|||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/getsentry/sentry-go"
|
||||
|
@ -34,16 +33,14 @@ import (
|
|||
|
||||
// OutputKeyChangeEventConsumer consumes events that originated in the key server.
|
||||
type OutputKeyChangeEventConsumer struct {
|
||||
ctx context.Context
|
||||
keyChangeConsumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
notifier *notifier.Notifier
|
||||
stream types.PartitionedStreamProvider
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
keyAPI api.KeyInternalAPI
|
||||
partitionToOffset map[int32]int64
|
||||
partitionToOffsetMu sync.Mutex
|
||||
ctx context.Context
|
||||
keyChangeConsumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
notifier *notifier.Notifier
|
||||
stream types.StreamProvider
|
||||
serverName gomatrixserverlib.ServerName // our server name
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
keyAPI api.KeyInternalAPI
|
||||
}
|
||||
|
||||
// NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
|
||||
|
@ -57,7 +54,7 @@ func NewOutputKeyChangeEventConsumer(
|
|||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
store storage.Database,
|
||||
notifier *notifier.Notifier,
|
||||
stream types.PartitionedStreamProvider,
|
||||
stream types.StreamProvider,
|
||||
) *OutputKeyChangeEventConsumer {
|
||||
|
||||
consumer := internal.ContinualConsumer{
|
||||
|
@ -69,16 +66,14 @@ func NewOutputKeyChangeEventConsumer(
|
|||
}
|
||||
|
||||
s := &OutputKeyChangeEventConsumer{
|
||||
ctx: process.Context(),
|
||||
keyChangeConsumer: &consumer,
|
||||
db: store,
|
||||
serverName: serverName,
|
||||
keyAPI: keyAPI,
|
||||
rsAPI: rsAPI,
|
||||
partitionToOffset: make(map[int32]int64),
|
||||
partitionToOffsetMu: sync.Mutex{},
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
ctx: process.Context(),
|
||||
keyChangeConsumer: &consumer,
|
||||
db: store,
|
||||
serverName: serverName,
|
||||
keyAPI: keyAPI,
|
||||
rsAPI: rsAPI,
|
||||
notifier: notifier,
|
||||
stream: stream,
|
||||
}
|
||||
|
||||
consumer.ProcessMessage = s.onMessage
|
||||
|
@ -88,24 +83,10 @@ func NewOutputKeyChangeEventConsumer(
|
|||
|
||||
// Start consuming from the key server
|
||||
func (s *OutputKeyChangeEventConsumer) Start() error {
|
||||
offsets, err := s.keyChangeConsumer.StartOffsets()
|
||||
s.partitionToOffsetMu.Lock()
|
||||
for _, o := range offsets {
|
||||
s.partitionToOffset[o.Partition] = o.Offset
|
||||
}
|
||||
s.partitionToOffsetMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) {
|
||||
s.partitionToOffsetMu.Lock()
|
||||
defer s.partitionToOffsetMu.Unlock()
|
||||
s.partitionToOffset[msg.Partition] = msg.Offset
|
||||
return s.keyChangeConsumer.Start()
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
defer s.updateOffset(msg)
|
||||
|
||||
var m api.DeviceMessage
|
||||
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||
|
@ -118,15 +99,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
|
|||
}
|
||||
switch m.Type {
|
||||
case api.TypeCrossSigningUpdate:
|
||||
return s.onCrossSigningMessage(m, msg.Offset, msg.Partition)
|
||||
return s.onCrossSigningMessage(m, m.DeviceChangeID)
|
||||
case api.TypeDeviceKeyUpdate:
|
||||
fallthrough
|
||||
default:
|
||||
return s.onDeviceKeyMessage(m, msg.Offset, msg.Partition)
|
||||
return s.onDeviceKeyMessage(m, m.DeviceChangeID)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64, partition int32) error {
|
||||
func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||
if m.DeviceKeys == nil {
|
||||
return nil
|
||||
}
|
||||
|
@ -143,10 +124,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
|
|||
}
|
||||
// make sure we get our own key updates too!
|
||||
queryRes.UserIDsToCount[output.UserID] = 1
|
||||
posUpdate := types.LogPosition{
|
||||
Offset: offset,
|
||||
Partition: partition,
|
||||
}
|
||||
posUpdate := types.StreamPosition(deviceChangeID)
|
||||
|
||||
s.stream.Advance(posUpdate)
|
||||
for userID := range queryRes.UserIDsToCount {
|
||||
|
@ -156,7 +134,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64, partition int32) error {
|
||||
func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error {
|
||||
output := m.CrossSigningKeyUpdate
|
||||
// work out who we need to notify about the new key
|
||||
var queryRes roomserverAPI.QuerySharedUsersResponse
|
||||
|
@ -170,10 +148,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
|
|||
}
|
||||
// make sure we get our own key updates too!
|
||||
queryRes.UserIDsToCount[output.UserID] = 1
|
||||
posUpdate := types.LogPosition{
|
||||
Offset: offset,
|
||||
Partition: partition,
|
||||
}
|
||||
posUpdate := types.StreamPosition(deviceChangeID)
|
||||
|
||||
s.stream.Advance(posUpdate)
|
||||
for userID := range queryRes.UserIDsToCount {
|
||||
|
|
|
@ -47,8 +47,8 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
|
|||
// be already filled in with join/leave information.
|
||||
func DeviceListCatchup(
|
||||
ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
userID string, res *types.Response, from, to types.LogPosition,
|
||||
) (newPos types.LogPosition, hasNew bool, err error) {
|
||||
userID string, res *types.Response, from, to types.StreamPosition,
|
||||
) (newPos types.StreamPosition, hasNew bool, err error) {
|
||||
|
||||
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
|
||||
newlyJoinedRooms := joinedRooms(res, userID)
|
||||
|
@ -64,27 +64,18 @@ func DeviceListCatchup(
|
|||
}
|
||||
|
||||
// now also track users who we already share rooms with but who have updated their devices between the two tokens
|
||||
|
||||
var partition int32
|
||||
var offset int64
|
||||
partition = -1
|
||||
offset = sarama.OffsetOldest
|
||||
// Extract partition/offset from sync token
|
||||
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
|
||||
if !from.IsEmpty() {
|
||||
partition = from.Partition
|
||||
offset = from.Offset
|
||||
offset := sarama.OffsetOldest
|
||||
toOffset := sarama.OffsetNewest
|
||||
if to > 0 && to > from {
|
||||
toOffset = int64(to)
|
||||
}
|
||||
var toOffset int64
|
||||
toOffset = sarama.OffsetNewest
|
||||
if toLog := to; toLog.Partition == partition && toLog.Offset > 0 {
|
||||
toOffset = toLog.Offset
|
||||
if from > 0 {
|
||||
offset = int64(from)
|
||||
}
|
||||
var queryRes keyapi.QueryKeyChangesResponse
|
||||
keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
|
||||
Partition: partition,
|
||||
Offset: offset,
|
||||
ToOffset: toOffset,
|
||||
Offset: offset,
|
||||
ToOffset: toOffset,
|
||||
}, &queryRes)
|
||||
if queryRes.Error != nil {
|
||||
// don't fail the catchup because we may have got useful information by tracking membership
|
||||
|
@ -95,8 +86,8 @@ func DeviceListCatchup(
|
|||
var sharedUsersMap map[string]int
|
||||
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
|
||||
util.GetLogger(ctx).Debugf(
|
||||
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v",
|
||||
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs,
|
||||
"QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
|
||||
offset, toOffset, queryRes.Offset, queryRes.UserIDs,
|
||||
)
|
||||
userSet := make(map[string]bool)
|
||||
for _, userID := range res.DeviceLists.Changed {
|
||||
|
@ -125,13 +116,8 @@ func DeviceListCatchup(
|
|||
res.DeviceLists.Left = append(res.DeviceLists.Left, userID)
|
||||
}
|
||||
}
|
||||
// set the new token
|
||||
to = types.LogPosition{
|
||||
Partition: queryRes.Partition,
|
||||
Offset: queryRes.Offset,
|
||||
}
|
||||
|
||||
return to, hasNew, nil
|
||||
return types.StreamPosition(queryRes.Offset), hasNew, nil
|
||||
}
|
||||
|
||||
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
keyapi "github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
@ -16,11 +15,7 @@ import (
|
|||
|
||||
var (
|
||||
syncingUser = "@alice:localhost"
|
||||
emptyToken = types.LogPosition{}
|
||||
newestToken = types.LogPosition{
|
||||
Offset: sarama.OffsetNewest,
|
||||
Partition: 0,
|
||||
}
|
||||
emptyToken = types.StreamPosition(0)
|
||||
)
|
||||
|
||||
type mockKeyAPI struct{}
|
||||
|
@ -186,7 +181,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
|
|||
"!another:room": {syncingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -209,7 +204,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
|
|||
"!another:room": {syncingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -232,7 +227,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
|
|||
"!another:room": {syncingUser, existingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Catchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -254,7 +249,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
|
|||
"!another:room": {syncingUser, existingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -313,7 +308,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
|
|||
roomID: {syncingUser, existingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -341,7 +336,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
|
|||
"!another:room": {syncingUser},
|
||||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken)
|
||||
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
|
||||
if err != nil {
|
||||
t.Fatalf("Catchup returned an error: %s", err)
|
||||
}
|
||||
|
@ -427,7 +422,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
|
|||
},
|
||||
}
|
||||
_, hasNew, err := DeviceListCatchup(
|
||||
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken,
|
||||
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("DeviceListCatchup returned an error: %s", err)
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
type DeviceListStreamProvider struct {
|
||||
PartitionedStreamProvider
|
||||
StreamProvider
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
keyAPI keyapi.KeyInternalAPI
|
||||
}
|
||||
|
@ -18,15 +18,15 @@ type DeviceListStreamProvider struct {
|
|||
func (p *DeviceListStreamProvider) CompleteSync(
|
||||
ctx context.Context,
|
||||
req *types.SyncRequest,
|
||||
) types.LogPosition {
|
||||
) types.StreamPosition {
|
||||
return p.LatestPosition(ctx)
|
||||
}
|
||||
|
||||
func (p *DeviceListStreamProvider) IncrementalSync(
|
||||
ctx context.Context,
|
||||
req *types.SyncRequest,
|
||||
from, to types.LogPosition,
|
||||
) types.LogPosition {
|
||||
from, to types.StreamPosition,
|
||||
) types.StreamPosition {
|
||||
var err error
|
||||
to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
|
||||
if err != nil {
|
||||
|
|
|
@ -18,7 +18,7 @@ type Streams struct {
|
|||
InviteStreamProvider types.StreamProvider
|
||||
SendToDeviceStreamProvider types.StreamProvider
|
||||
AccountDataStreamProvider types.StreamProvider
|
||||
DeviceListStreamProvider types.PartitionedStreamProvider
|
||||
DeviceListStreamProvider types.StreamProvider
|
||||
}
|
||||
|
||||
func NewSyncStreamProviders(
|
||||
|
@ -48,9 +48,9 @@ func NewSyncStreamProviders(
|
|||
userAPI: userAPI,
|
||||
},
|
||||
DeviceListStreamProvider: &DeviceListStreamProvider{
|
||||
PartitionedStreamProvider: PartitionedStreamProvider{DB: d},
|
||||
rsAPI: rsAPI,
|
||||
keyAPI: keyAPI,
|
||||
StreamProvider: StreamProvider{DB: d},
|
||||
rsAPI: rsAPI,
|
||||
keyAPI: keyAPI,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
|
@ -1,38 +0,0 @@
|
|||
package streams
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
)
|
||||
|
||||
type PartitionedStreamProvider struct {
|
||||
DB storage.Database
|
||||
latest types.LogPosition
|
||||
latestMutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (p *PartitionedStreamProvider) Setup() {
|
||||
}
|
||||
|
||||
func (p *PartitionedStreamProvider) Advance(
|
||||
latest types.LogPosition,
|
||||
) {
|
||||
p.latestMutex.Lock()
|
||||
defer p.latestMutex.Unlock()
|
||||
|
||||
if latest.IsAfter(&p.latest) {
|
||||
p.latest = latest
|
||||
}
|
||||
}
|
||||
|
||||
func (p *PartitionedStreamProvider) LatestPosition(
|
||||
ctx context.Context,
|
||||
) types.LogPosition {
|
||||
p.latestMutex.RLock()
|
||||
defer p.latestMutex.RUnlock()
|
||||
|
||||
return p.latest
|
||||
}
|
|
@ -140,6 +140,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
// Extract values from request
|
||||
syncReq, err := newSyncRequest(req, *device, rp.db)
|
||||
if err != nil {
|
||||
if err == types.ErrMalformedSyncToken {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.InvalidArgumentValue(err.Error()),
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.Unknown(err.Error()),
|
||||
|
|
|
@ -42,11 +42,3 @@ type StreamProvider interface {
|
|||
// LatestPosition returns the latest stream position for this stream.
|
||||
LatestPosition(ctx context.Context) StreamPosition
|
||||
}
|
||||
|
||||
type PartitionedStreamProvider interface {
|
||||
Setup()
|
||||
Advance(latest LogPosition)
|
||||
CompleteSync(ctx context.Context, req *SyncRequest) LogPosition
|
||||
IncrementalSync(ctx context.Context, req *SyncRequest, from, to LogPosition) LogPosition
|
||||
LatestPosition(ctx context.Context) LogPosition
|
||||
}
|
||||
|
|
|
@ -16,6 +16,7 @@ package types
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
@ -26,13 +27,10 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidSyncTokenType is returned when an attempt at creating a
|
||||
// new instance of SyncToken with an invalid type (i.e. neither "s"
|
||||
// nor "t").
|
||||
ErrInvalidSyncTokenType = fmt.Errorf("sync token has an unknown prefix (should be either s or t)")
|
||||
// ErrInvalidSyncTokenLen is returned when the pagination token is an
|
||||
// invalid length
|
||||
ErrInvalidSyncTokenLen = fmt.Errorf("sync token has an invalid length")
|
||||
// This error is returned when parsing sync tokens if the token is invalid. Callers can use this
|
||||
// error to detect whether to 400 or 401 the client. It is recommended to 401 them to force a
|
||||
// logout.
|
||||
ErrMalformedSyncToken = errors.New("malformed sync token")
|
||||
)
|
||||
|
||||
type StateDelta struct {
|
||||
|
@ -47,27 +45,6 @@ type StateDelta struct {
|
|||
// StreamPosition represents the offset in the sync stream a client is at.
|
||||
type StreamPosition int64
|
||||
|
||||
// LogPosition represents the offset in a Kafka log a client is at.
|
||||
type LogPosition struct {
|
||||
Partition int32
|
||||
Offset int64
|
||||
}
|
||||
|
||||
func (p *LogPosition) IsEmpty() bool {
|
||||
return p.Offset == 0
|
||||
}
|
||||
|
||||
// IsAfter returns true if this position is after `lp`.
|
||||
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
|
||||
if lp == nil {
|
||||
return false
|
||||
}
|
||||
if p.Partition != lp.Partition {
|
||||
return false
|
||||
}
|
||||
return p.Offset > lp.Offset
|
||||
}
|
||||
|
||||
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
|
||||
type StreamEvent struct {
|
||||
*gomatrixserverlib.HeaderedEvent
|
||||
|
@ -124,7 +101,7 @@ type StreamingToken struct {
|
|||
SendToDevicePosition StreamPosition
|
||||
InvitePosition StreamPosition
|
||||
AccountDataPosition StreamPosition
|
||||
DeviceListPosition LogPosition
|
||||
DeviceListPosition StreamPosition
|
||||
}
|
||||
|
||||
// This will be used as a fallback by json.Marshal.
|
||||
|
@ -140,14 +117,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
|
|||
|
||||
func (t StreamingToken) String() string {
|
||||
posStr := fmt.Sprintf(
|
||||
"s%d_%d_%d_%d_%d_%d",
|
||||
"s%d_%d_%d_%d_%d_%d_%d",
|
||||
t.PDUPosition, t.TypingPosition,
|
||||
t.ReceiptPosition, t.SendToDevicePosition,
|
||||
t.InvitePosition, t.AccountDataPosition,
|
||||
t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition,
|
||||
)
|
||||
if dl := t.DeviceListPosition; !dl.IsEmpty() {
|
||||
posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset)
|
||||
}
|
||||
return posStr
|
||||
}
|
||||
|
||||
|
@ -166,14 +140,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
|
|||
return true
|
||||
case t.AccountDataPosition > other.AccountDataPosition:
|
||||
return true
|
||||
case t.DeviceListPosition.IsAfter(&other.DeviceListPosition):
|
||||
case t.DeviceListPosition > other.DeviceListPosition:
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (t *StreamingToken) IsEmpty() bool {
|
||||
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition == 0 && t.DeviceListPosition.IsEmpty()
|
||||
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0
|
||||
}
|
||||
|
||||
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
|
||||
|
@ -208,7 +182,7 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
|
|||
if other.AccountDataPosition > t.AccountDataPosition {
|
||||
t.AccountDataPosition = other.AccountDataPosition
|
||||
}
|
||||
if other.DeviceListPosition.IsAfter(&t.DeviceListPosition) {
|
||||
if other.DeviceListPosition > t.DeviceListPosition {
|
||||
t.DeviceListPosition = other.DeviceListPosition
|
||||
}
|
||||
}
|
||||
|
@ -292,16 +266,18 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
|
|||
|
||||
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
||||
if len(tok) < 1 {
|
||||
err = fmt.Errorf("empty stream token")
|
||||
err = ErrMalformedSyncToken
|
||||
return
|
||||
}
|
||||
if tok[0] != SyncTokenTypeStream[0] {
|
||||
err = fmt.Errorf("stream token must start with 's'")
|
||||
err = ErrMalformedSyncToken
|
||||
return
|
||||
}
|
||||
categories := strings.Split(tok[1:], ".")
|
||||
parts := strings.Split(categories[0], "_")
|
||||
var positions [6]StreamPosition
|
||||
// Migration: Remove everything after and including '.' - we previously had tokens like:
|
||||
// s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions
|
||||
tok = strings.Split(tok, ".")[0]
|
||||
parts := strings.Split(tok[1:], "_")
|
||||
var positions [7]StreamPosition
|
||||
for i, p := range parts {
|
||||
if i > len(positions) {
|
||||
break
|
||||
|
@ -309,6 +285,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
|||
var pos int
|
||||
pos, err = strconv.Atoi(p)
|
||||
if err != nil {
|
||||
err = ErrMalformedSyncToken
|
||||
return
|
||||
}
|
||||
positions[i] = StreamPosition(pos)
|
||||
|
@ -320,31 +297,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
|
|||
SendToDevicePosition: positions[3],
|
||||
InvitePosition: positions[4],
|
||||
AccountDataPosition: positions[5],
|
||||
}
|
||||
// dl-0-1234
|
||||
// $log_name-$partition-$offset
|
||||
for _, logStr := range categories[1:] {
|
||||
segments := strings.Split(logStr, "-")
|
||||
if len(segments) != 3 {
|
||||
err = fmt.Errorf("invalid log position %q", logStr)
|
||||
return
|
||||
}
|
||||
switch segments[0] {
|
||||
case "dl":
|
||||
// Device list syncing
|
||||
var partition, offset int
|
||||
if partition, err = strconv.Atoi(segments[1]); err != nil {
|
||||
return
|
||||
}
|
||||
if offset, err = strconv.Atoi(segments[2]); err != nil {
|
||||
return
|
||||
}
|
||||
token.DeviceListPosition.Partition = int32(partition)
|
||||
token.DeviceListPosition.Offset = int64(offset)
|
||||
default:
|
||||
err = fmt.Errorf("unrecognised token type %q", segments[0])
|
||||
return
|
||||
}
|
||||
DeviceListPosition: positions[6],
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
|
|
@ -2,50 +2,17 @@ package types
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
func TestNewSyncTokenWithLogs(t *testing.T) {
|
||||
tests := map[string]*StreamingToken{
|
||||
"s4_0_0_0_0_0": {
|
||||
PDUPosition: 4,
|
||||
},
|
||||
"s4_0_0_0_0_0.dl-0-123": {
|
||||
PDUPosition: 4,
|
||||
DeviceListPosition: LogPosition{
|
||||
Partition: 0,
|
||||
Offset: 123,
|
||||
},
|
||||
},
|
||||
}
|
||||
for tok, want := range tests {
|
||||
got, err := NewStreamTokenFromString(tok)
|
||||
if err != nil {
|
||||
if want == nil {
|
||||
continue // error expected
|
||||
}
|
||||
t.Errorf("%s errored: %s", tok, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(got, *want) {
|
||||
t.Errorf("%s mismatch: got %v want %v", tok, got, want)
|
||||
}
|
||||
gotStr := got.String()
|
||||
if gotStr != tok {
|
||||
t.Errorf("%s reserialisation mismatch: got %s want %s", tok, gotStr, tok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSyncTokens(t *testing.T) {
|
||||
shouldPass := map[string]string{
|
||||
"s4_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, LogPosition{}}.String(),
|
||||
"s3_1_0_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, 0, LogPosition{1, 2}}.String(),
|
||||
"s3_1_2_3_5_0": StreamingToken{3, 1, 2, 3, 5, 0, LogPosition{}}.String(),
|
||||
"t3_1": TopologyToken{3, 1}.String(),
|
||||
"s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(),
|
||||
"s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(),
|
||||
"s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(),
|
||||
"t3_1": TopologyToken{3, 1}.String(),
|
||||
}
|
||||
|
||||
for a, b := range shouldPass {
|
||||
|
|
|
@ -588,3 +588,4 @@ User can invite remote user to room with version 9
|
|||
Remote user can backfill in a room with version 9
|
||||
Can reject invites over federation for rooms with version 9
|
||||
Can receive redactions from regular users over federation in room version 9
|
||||
Forward extremities remain so even after the next events are populated as outliers
|
||||
|
|
Loading…
Reference in a new issue