From db7d9cba8ad28c300dd66c01b9b0ce2414e8cbe0 Mon Sep 17 00:00:00 2001 From: kegsay Date: Thu, 20 Jan 2022 15:26:45 +0000 Subject: [PATCH 1/6] BREAKING: Remove Partitioned Stream Positions (#2096) * go mod tidy * Break complement to check it fails CI * Remove partitioned stream positions This was used by the device list stream position. The device list position now corresponds to the `Offset`, and the partition is always 0, in prep for removing reliance on Kafka topics for device list changes. * Linting * Migrate old style tokens to new style because element-web doesn't soft-logout on 4xx errors on /sync --- go.mod | 2 +- keyserver/api/api.go | 4 -- keyserver/internal/internal.go | 7 +-- keyserver/producers/keychange.go | 9 --- syncapi/consumers/keychange.go | 22 +++---- syncapi/internal/keychange.go | 40 +++++-------- syncapi/internal/keychange_test.go | 21 +++---- syncapi/streams/stream_devicelist.go | 8 +-- syncapi/streams/streams.go | 8 +-- syncapi/streams/template_pstream.go | 38 ------------ syncapi/sync/requestpool.go | 6 ++ syncapi/types/provider.go | 8 --- syncapi/types/types.go | 87 +++++++--------------------- syncapi/types/types_test.go | 41 ++----------- 14 files changed, 70 insertions(+), 231 deletions(-) delete mode 100644 syncapi/streams/template_pstream.go diff --git a/go.mod b/go.mod index eb421ce1..e6202f66 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/Arceliar/phony v0.0.0-20210209235338-dde1a8dca979 github.com/DATA-DOG/go-sqlmock v1.5.0 github.com/HdrHistogram/hdrhistogram-go v1.0.1 // indirect - github.com/MFAshby/stdemuxerhook v1.0.0 // indirect + github.com/MFAshby/stdemuxerhook v1.0.0 github.com/Masterminds/semver/v3 v3.1.1 github.com/S7evinK/saramajetstream v0.0.0-20210709110708-de6efc8c4a32 github.com/Shopify/sarama v1.29.0 diff --git a/keyserver/api/api.go b/keyserver/api/api.go index 5a109cc6..eae14ae1 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -224,8 +224,6 @@ type QueryKeysResponse struct { } type QueryKeyChangesRequest struct { - // The partition which had key events sent to - Partition int32 // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning Offset int64 // The inclusive offset where to track key changes up to. Messages with this offset are included in the response. @@ -236,8 +234,6 @@ type QueryKeyChangesRequest struct { type QueryKeyChangesResponse struct { // The set of users who have had their keys change. UserIDs []string - // The partition being served - useful if the partition is unknown at request time - Partition int32 // The latest offset represented in this response. Offset int64 // Set if there was a problem handling the request.
diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 3e91962e..0140a8f4 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -59,17 +59,14 @@ func (a *KeyInternalAPI) InputDeviceListUpdate( } func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { - if req.Partition < 0 { - req.Partition = a.Producer.DefaultPartition() - } - userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset, req.ToOffset) + partition := 0 + userIDs, latest, err := a.DB.KeyChanges(ctx, int32(partition), req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), } } res.Offset = latest - res.Partition = req.Partition res.UserIDs = userIDs } diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index 782675c2..c5ddf2c1 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -32,15 +32,6 @@ type KeyChange struct { DB storage.Database } -// DefaultPartition returns the default partition this process is sending key changes to. -// NB: A keyserver MUST send key changes to only 1 partition or else query operations will -// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but -// then all keyservers must be queried to calculate the entire set of key changes between -// two sync tokens. -func (p *KeyChange) DefaultPartition() int32 { - return 0 -} - // ProduceKeyChanges creates new change events for each key func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { userToDeviceCount := make(map[string]int) diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index 76b143d8..d63e4832 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -38,7 +38,7 @@ type OutputKeyChangeEventConsumer struct { keyChangeConsumer *internal.ContinualConsumer db storage.Database notifier *notifier.Notifier - stream types.PartitionedStreamProvider + stream types.StreamProvider serverName gomatrixserverlib.ServerName // our server name rsAPI roomserverAPI.RoomserverInternalAPI keyAPI api.KeyInternalAPI @@ -57,7 +57,7 @@ func NewOutputKeyChangeEventConsumer( rsAPI roomserverAPI.RoomserverInternalAPI, store storage.Database, notifier *notifier.Notifier, - stream types.PartitionedStreamProvider, + stream types.StreamProvider, ) *OutputKeyChangeEventConsumer { consumer := internal.ContinualConsumer{ @@ -118,15 +118,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er } switch m.Type { case api.TypeCrossSigningUpdate: - return s.onCrossSigningMessage(m, msg.Offset, msg.Partition) + return s.onCrossSigningMessage(m, msg.Offset) case api.TypeDeviceKeyUpdate: fallthrough default: - return s.onDeviceKeyMessage(m, msg.Offset, msg.Partition) + return s.onDeviceKeyMessage(m, msg.Offset) } } -func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64, partition int32) error { +func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64) error { if m.DeviceKeys == nil { return nil } @@ -143,10 +143,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o } // make sure we get our own key updates too! 
queryRes.UserIDsToCount[output.UserID] = 1 - posUpdate := types.LogPosition{ - Offset: offset, - Partition: partition, - } + posUpdate := types.StreamPosition(offset) s.stream.Advance(posUpdate) for userID := range queryRes.UserIDsToCount { @@ -156,7 +153,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o return nil } -func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64, partition int32) error { +func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64) error { output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse @@ -170,10 +167,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage } // make sure we get our own key updates too! queryRes.UserIDsToCount[output.UserID] = 1 - posUpdate := types.LogPosition{ - Offset: offset, - Partition: partition, - } + posUpdate := types.StreamPosition(offset) s.stream.Advance(posUpdate) for userID := range queryRes.UserIDsToCount { diff --git a/syncapi/internal/keychange.go b/syncapi/internal/keychange.go index 56a438fb..41efd4a0 100644 --- a/syncapi/internal/keychange.go +++ b/syncapi/internal/keychange.go @@ -47,8 +47,8 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID, // be already filled in with join/leave information. func DeviceListCatchup( ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, - userID string, res *types.Response, from, to types.LogPosition, -) (newPos types.LogPosition, hasNew bool, err error) { + userID string, res *types.Response, from, to types.StreamPosition, +) (newPos types.StreamPosition, hasNew bool, err error) { // Track users who we didn't track before but now do by virtue of sharing a room with them, or not. newlyJoinedRooms := joinedRooms(res, userID) @@ -64,27 +64,18 @@ func DeviceListCatchup( } // now also track users who we already share rooms with but who have updated their devices between the two tokens - - var partition int32 - var offset int64 - partition = -1 - offset = sarama.OffsetOldest - // Extract partition/offset from sync token - // TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make. 
- if !from.IsEmpty() { - partition = from.Partition - offset = from.Offset + offset := sarama.OffsetOldest + toOffset := sarama.OffsetNewest + if to > 0 && to > from { + toOffset = int64(to) } - var toOffset int64 - toOffset = sarama.OffsetNewest - if toLog := to; toLog.Partition == partition && toLog.Offset > 0 { - toOffset = toLog.Offset + if from > 0 { + offset = int64(from) } var queryRes keyapi.QueryKeyChangesResponse keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ - Partition: partition, - Offset: offset, - ToOffset: toOffset, + Offset: offset, + ToOffset: toOffset, }, &queryRes) if queryRes.Error != nil { // don't fail the catchup because we may have got useful information by tracking membership @@ -95,8 +86,8 @@ func DeviceListCatchup( var sharedUsersMap map[string]int sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) util.GetLogger(ctx).Debugf( - "QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", - partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, + "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v", + offset, toOffset, queryRes.Offset, queryRes.UserIDs, ) userSet := make(map[string]bool) for _, userID := range res.DeviceLists.Changed { @@ -125,13 +116,8 @@ func DeviceListCatchup( res.DeviceLists.Left = append(res.DeviceLists.Left, userID) } } - // set the new token - to = types.LogPosition{ - Partition: queryRes.Partition, - Offset: queryRes.Offset, - } - return to, hasNew, nil + return types.StreamPosition(queryRes.Offset), hasNew, nil } // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. diff --git a/syncapi/internal/keychange_test.go b/syncapi/internal/keychange_test.go index e52e5556..d9fb9cf8 100644 --- a/syncapi/internal/keychange_test.go +++ b/syncapi/internal/keychange_test.go @@ -6,7 +6,6 @@ import ( "sort" "testing" - "github.com/Shopify/sarama" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/syncapi/types" @@ -16,11 +15,7 @@ import ( var ( syncingUser = "@alice:localhost" - emptyToken = types.LogPosition{} - newestToken = types.LogPosition{ - Offset: sarama.OffsetNewest, - Partition: 0, - } + emptyToken = types.StreamPosition(0) ) type mockKeyAPI struct{} @@ -186,7 +181,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -209,7 +204,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -232,7 +227,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) { "!another:room": {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + 
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -254,7 +249,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) { "!another:room": {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -313,7 +308,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) { roomID: {syncingUser, existingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) } @@ -341,7 +336,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) { "!another:room": {syncingUser}, }, } - _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) + _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken) if err != nil { t.Fatalf("Catchup returned an error: %s", err) } @@ -427,7 +422,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) { }, } _, hasNew, err := DeviceListCatchup( - context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken, + context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken, ) if err != nil { t.Fatalf("DeviceListCatchup returned an error: %s", err) diff --git a/syncapi/streams/stream_devicelist.go b/syncapi/streams/stream_devicelist.go index 9ea9d088..6ff8a7fd 100644 --- a/syncapi/streams/stream_devicelist.go +++ b/syncapi/streams/stream_devicelist.go @@ -10,7 +10,7 @@ import ( ) type DeviceListStreamProvider struct { - PartitionedStreamProvider + StreamProvider rsAPI api.RoomserverInternalAPI keyAPI keyapi.KeyInternalAPI } @@ -18,15 +18,15 @@ type DeviceListStreamProvider struct { func (p *DeviceListStreamProvider) CompleteSync( ctx context.Context, req *types.SyncRequest, -) types.LogPosition { +) types.StreamPosition { return p.LatestPosition(ctx) } func (p *DeviceListStreamProvider) IncrementalSync( ctx context.Context, req *types.SyncRequest, - from, to types.LogPosition, -) types.LogPosition { + from, to types.StreamPosition, +) types.StreamPosition { var err error to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) if err != nil { diff --git a/syncapi/streams/streams.go b/syncapi/streams/streams.go index ba4118df..6b02c75e 100644 --- a/syncapi/streams/streams.go +++ b/syncapi/streams/streams.go @@ -18,7 +18,7 @@ type Streams struct { InviteStreamProvider types.StreamProvider SendToDeviceStreamProvider types.StreamProvider AccountDataStreamProvider types.StreamProvider - DeviceListStreamProvider types.PartitionedStreamProvider + DeviceListStreamProvider types.StreamProvider } func NewSyncStreamProviders( @@ -48,9 +48,9 @@ func NewSyncStreamProviders( userAPI: userAPI, }, DeviceListStreamProvider: 
&DeviceListStreamProvider{ - PartitionedStreamProvider: PartitionedStreamProvider{DB: d}, - rsAPI: rsAPI, - keyAPI: keyAPI, + StreamProvider: StreamProvider{DB: d}, + rsAPI: rsAPI, + keyAPI: keyAPI, }, } diff --git a/syncapi/streams/template_pstream.go b/syncapi/streams/template_pstream.go deleted file mode 100644 index 265e22a2..00000000 --- a/syncapi/streams/template_pstream.go +++ /dev/null @@ -1,38 +0,0 @@ -package streams - -import ( - "context" - "sync" - - "github.com/matrix-org/dendrite/syncapi/storage" - "github.com/matrix-org/dendrite/syncapi/types" -) - -type PartitionedStreamProvider struct { - DB storage.Database - latest types.LogPosition - latestMutex sync.RWMutex -} - -func (p *PartitionedStreamProvider) Setup() { -} - -func (p *PartitionedStreamProvider) Advance( - latest types.LogPosition, -) { - p.latestMutex.Lock() - defer p.latestMutex.Unlock() - - if latest.IsAfter(&p.latest) { - p.latest = latest - } -} - -func (p *PartitionedStreamProvider) LatestPosition( - ctx context.Context, -) types.LogPosition { - p.latestMutex.RLock() - defer p.latestMutex.RUnlock() - - return p.latest -} diff --git a/syncapi/sync/requestpool.go b/syncapi/sync/requestpool.go index a4573610..ca35951a 100644 --- a/syncapi/sync/requestpool.go +++ b/syncapi/sync/requestpool.go @@ -140,6 +140,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi. // Extract values from request syncReq, err := newSyncRequest(req, *device, rp.db) if err != nil { + if err == types.ErrMalformedSyncToken { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue(err.Error()), + } + } return util.JSONResponse{ Code: http.StatusBadRequest, JSON: jsonerror.Unknown(err.Error()), diff --git a/syncapi/types/provider.go b/syncapi/types/provider.go index 93ed1266..f6185fcb 100644 --- a/syncapi/types/provider.go +++ b/syncapi/types/provider.go @@ -42,11 +42,3 @@ type StreamProvider interface { // LatestPosition returns the latest stream position for this stream. LatestPosition(ctx context.Context) StreamPosition } - -type PartitionedStreamProvider interface { - Setup() - Advance(latest LogPosition) - CompleteSync(ctx context.Context, req *SyncRequest) LogPosition - IncrementalSync(ctx context.Context, req *SyncRequest, from, to LogPosition) LogPosition - LatestPosition(ctx context.Context) LogPosition -} diff --git a/syncapi/types/types.go b/syncapi/types/types.go index 44e718b3..68c308d8 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -16,6 +16,7 @@ package types import ( "encoding/json" + "errors" "fmt" "strconv" "strings" @@ -26,13 +27,10 @@ import ( ) var ( - // ErrInvalidSyncTokenType is returned when an attempt at creating a - // new instance of SyncToken with an invalid type (i.e. neither "s" - // nor "t"). - ErrInvalidSyncTokenType = fmt.Errorf("sync token has an unknown prefix (should be either s or t)") - // ErrInvalidSyncTokenLen is returned when the pagination token is an - // invalid length - ErrInvalidSyncTokenLen = fmt.Errorf("sync token has an invalid length") + // This error is returned when parsing sync tokens if the token is invalid. Callers can use this + // error to detect whether to 400 or 401 the client. It is recommended to 401 them to force a + // logout. + ErrMalformedSyncToken = errors.New("malformed sync token") ) type StateDelta struct { @@ -47,27 +45,6 @@ type StateDelta struct { // StreamPosition represents the offset in the sync stream a client is at. 
type StreamPosition int64 -// LogPosition represents the offset in a Kafka log a client is at. -type LogPosition struct { - Partition int32 - Offset int64 -} - -func (p *LogPosition) IsEmpty() bool { - return p.Offset == 0 -} - -// IsAfter returns true if this position is after `lp`. -func (p *LogPosition) IsAfter(lp *LogPosition) bool { - if lp == nil { - return false - } - if p.Partition != lp.Partition { - return false - } - return p.Offset > lp.Offset -} - // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. type StreamEvent struct { *gomatrixserverlib.HeaderedEvent @@ -124,7 +101,7 @@ type StreamingToken struct { SendToDevicePosition StreamPosition InvitePosition StreamPosition AccountDataPosition StreamPosition - DeviceListPosition LogPosition + DeviceListPosition StreamPosition } // This will be used as a fallback by json.Marshal. @@ -140,14 +117,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) { func (t StreamingToken) String() string { posStr := fmt.Sprintf( - "s%d_%d_%d_%d_%d_%d", + "s%d_%d_%d_%d_%d_%d_%d", t.PDUPosition, t.TypingPosition, t.ReceiptPosition, t.SendToDevicePosition, - t.InvitePosition, t.AccountDataPosition, + t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition, ) - if dl := t.DeviceListPosition; !dl.IsEmpty() { - posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset) - } return posStr } @@ -166,14 +140,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool { return true case t.AccountDataPosition > other.AccountDataPosition: return true - case t.DeviceListPosition.IsAfter(&other.DeviceListPosition): + case t.DeviceListPosition > other.DeviceListPosition: return true } return false } func (t *StreamingToken) IsEmpty() bool { - return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition == 0 && t.DeviceListPosition.IsEmpty() + return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0 } // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. @@ -208,7 +182,7 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) { if other.AccountDataPosition > t.AccountDataPosition { t.AccountDataPosition = other.AccountDataPosition } - if other.DeviceListPosition.IsAfter(&t.DeviceListPosition) { + if other.DeviceListPosition > t.DeviceListPosition { t.DeviceListPosition = other.DeviceListPosition } } @@ -292,16 +266,18 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { if len(tok) < 1 { - err = fmt.Errorf("empty stream token") + err = ErrMalformedSyncToken return } if tok[0] != SyncTokenTypeStream[0] { - err = fmt.Errorf("stream token must start with 's'") + err = ErrMalformedSyncToken return } - categories := strings.Split(tok[1:], ".") - parts := strings.Split(categories[0], "_") - var positions [6]StreamPosition + // Migration: Remove everything after and including '.' 
- we previously had tokens like: + // s478_0_0_0_0_13.dl-0-2 but we have now removed partitioned stream positions + tok = strings.Split(tok, ".")[0] + parts := strings.Split(tok[1:], "_") + var positions [7]StreamPosition for i, p := range parts { if i > len(positions) { break @@ -309,6 +285,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { var pos int pos, err = strconv.Atoi(p) if err != nil { + err = ErrMalformedSyncToken return } positions[i] = StreamPosition(pos) @@ -320,31 +297,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { SendToDevicePosition: positions[3], InvitePosition: positions[4], AccountDataPosition: positions[5], - } - // dl-0-1234 - // $log_name-$partition-$offset - for _, logStr := range categories[1:] { - segments := strings.Split(logStr, "-") - if len(segments) != 3 { - err = fmt.Errorf("invalid log position %q", logStr) - return - } - switch segments[0] { - case "dl": - // Device list syncing - var partition, offset int - if partition, err = strconv.Atoi(segments[1]); err != nil { - return - } - if offset, err = strconv.Atoi(segments[2]); err != nil { - return - } - token.DeviceListPosition.Partition = int32(partition) - token.DeviceListPosition.Offset = int64(offset) - default: - err = fmt.Errorf("unrecognised token type %q", segments[0]) - return - } + DeviceListPosition: positions[6], } return token, nil } diff --git a/syncapi/types/types_test.go b/syncapi/types/types_test.go index 3e577788..cda178b3 100644 --- a/syncapi/types/types_test.go +++ b/syncapi/types/types_test.go @@ -2,50 +2,17 @@ package types import ( "encoding/json" - "reflect" "testing" "github.com/matrix-org/gomatrixserverlib" ) -func TestNewSyncTokenWithLogs(t *testing.T) { - tests := map[string]*StreamingToken{ - "s4_0_0_0_0_0": { - PDUPosition: 4, - }, - "s4_0_0_0_0_0.dl-0-123": { - PDUPosition: 4, - DeviceListPosition: LogPosition{ - Partition: 0, - Offset: 123, - }, - }, - } - for tok, want := range tests { - got, err := NewStreamTokenFromString(tok) - if err != nil { - if want == nil { - continue // error expected - } - t.Errorf("%s errored: %s", tok, err) - continue - } - if !reflect.DeepEqual(got, *want) { - t.Errorf("%s mismatch: got %v want %v", tok, got, want) - } - gotStr := got.String() - if gotStr != tok { - t.Errorf("%s reserialisation mismatch: got %s want %s", tok, gotStr, tok) - } - } -} - func TestSyncTokens(t *testing.T) { shouldPass := map[string]string{ - "s4_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, LogPosition{}}.String(), - "s3_1_0_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, 0, LogPosition{1, 2}}.String(), - "s3_1_2_3_5_0": StreamingToken{3, 1, 2, 3, 5, 0, LogPosition{}}.String(), - "t3_1": TopologyToken{3, 1}.String(), + "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(), + "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(), + "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(), + "t3_1": TopologyToken{3, 1}.String(), } for a, b := range shouldPass { From 2c581377a5f6b243300445947fd0fd38a919873d Mon Sep 17 00:00:00 2001 From: kegsay Date: Fri, 21 Jan 2022 09:56:06 +0000 Subject: [PATCH 2/6] Remodel how device list change IDs are created (#2098) * Remodel how device list change IDs are created Previously we made them using the offset Kafka supplied. We don't run Kafka anymore, so now we make the SQL table assign the change ID via an AUTOINCREMENTing ID. 
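As an illustrative aside (not part of the patch): the sketch below shows the intended behaviour using the storage API introduced here. The function name `exampleKeyChangeIDs` and the user ID are made up, and it assumes the usual `context`, `fmt`, `github.com/Shopify/sarama` and `keyserver/storage` imports.

    // Sketch: change IDs now come from the database, not from Kafka offsets.
    func exampleKeyChangeIDs(ctx context.Context, db storage.Database) error {
    	id1, err := db.StoreKeyChange(ctx, "@alice:localhost")
    	if err != nil {
    		return err
    	}
    	id2, err := db.StoreKeyChange(ctx, "@alice:localhost")
    	if err != nil {
    		return err
    	}
    	// Repeated changes for the same user bump the ID rather than adding rows.
    	fmt.Println(id2 > id1) // true
    	// Querying from 0 to sarama.OffsetNewest returns each changed user once,
    	// along with the highest change ID seen.
    	userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest)
    	fmt.Println(userIDs, latest) // [@alice:localhost] <id2>
    	return err
    }

This mirrors TestKeyChangesNoDupes further down in the patch.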
Redesign the `keyserver_key_changes` table to have `UNIQUE(user_id)` so we don't accumulate key changes forevermore; we now have at most 1 row per user which contains the highest change ID. This needs a SQL migration. * Ensure we bump the change ID on sqlite * Actually read the DeviceChangeID not the Offset in syncapi * Add SQL migrations * Prepare after migration; fixup dendrite-upgrade-test logging * Use higher version numbers; fix sqlite query to increment better * Default 0 on postgres * fixup postgres migration on fresh dendrite instances --- cmd/dendrite-upgrade-tests/main.go | 6 +- keyserver/api/api.go | 3 +- keyserver/internal/internal.go | 3 +- keyserver/keyserver.go | 8 +- keyserver/producers/keychange.go | 54 +++++++------ keyserver/storage/interface.go | 8 +- .../deltas/2022012016470000_key_changes.go | 79 +++++++++++++++++++ .../storage/postgres/key_changes_table.go | 55 +++++++------ keyserver/storage/postgres/storage.go | 9 +++ keyserver/storage/shared/storage.go | 12 +-- .../deltas/2022012016470000_key_changes.go | 76 ++++++++++++++++++ .../storage/sqlite3/key_changes_table.go | 52 ++++++------ keyserver/storage/sqlite3/storage.go | 10 +++ keyserver/storage/storage_test.go | 48 ++++++----- keyserver/storage/tables/interface.go | 6 +- syncapi/consumers/keychange.go | 65 ++++++--------- 16 files changed, 336 insertions(+), 158 deletions(-) create mode 100644 keyserver/storage/postgres/deltas/2022012016470000_key_changes.go create mode 100644 keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go diff --git a/cmd/dendrite-upgrade-tests/main.go b/cmd/dendrite-upgrade-tests/main.go index 35852e8a..3241234a 100644 --- a/cmd/dendrite-upgrade-tests/main.go +++ b/cmd/dendrite-upgrade-tests/main.go @@ -189,7 +189,9 @@ func buildDendrite(httpClient *http.Client, dockerClient *client.Client, tmpDir, if err := decoder.Decode(&dl); err != nil { return "", fmt.Errorf("failed to decode build image output line: %w", err) } - log.Printf("%s: %s", branchOrTagName, dl.Stream) + if len(strings.TrimSpace(dl.Stream)) > 0 { + log.Printf("%s: %s", branchOrTagName, dl.Stream) + } if dl.Aux != nil { imgID, ok := dl.Aux["ID"] if ok { @@ -425,8 +427,10 @@ func cleanup(dockerClient *client.Client) { // ignore all errors, we are just cleaning up and don't want to fail just because we fail to cleanup containers, _ := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{ Filters: label(dendriteUpgradeTestLabel), + All: true, }) for _, c := range containers { + log.Printf("Removing container: %v %v\n", c.ID, c.Names) s := time.Second _ = dockerClient.ContainerStop(context.Background(), c.ID, &s) _ = dockerClient.ContainerRemove(context.Background(), c.ID, types.ContainerRemoveOptions{ diff --git a/keyserver/api/api.go b/keyserver/api/api.go index eae14ae1..0eea2f0f 100644 --- a/keyserver/api/api.go +++ b/keyserver/api/api.go @@ -69,7 +69,8 @@ type DeviceMessage struct { *DeviceKeys `json:"DeviceKeys,omitempty"` *eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"` // A monotonically increasing number which represents device changes for this user.
- StreamID int + StreamID int + DeviceChangeID int64 } // DeviceKeys represents a set of device keys for a single device diff --git a/keyserver/internal/internal.go b/keyserver/internal/internal.go index 0140a8f4..25924921 100644 --- a/keyserver/internal/internal.go +++ b/keyserver/internal/internal.go @@ -59,8 +59,7 @@ func (a *KeyInternalAPI) InputDeviceListUpdate( } func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { - partition := 0 - userIDs, latest, err := a.DB.KeyChanges(ctx, int32(partition), req.Offset, req.ToOffset) + userIDs, latest, err := a.DB.KeyChanges(ctx, req.Offset, req.ToOffset) if err != nil { res.Error = &api.KeyError{ Err: err.Error(), diff --git a/keyserver/keyserver.go b/keyserver/keyserver.go index 03a221a6..8cc50ea0 100644 --- a/keyserver/keyserver.go +++ b/keyserver/keyserver.go @@ -40,16 +40,16 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI) { func NewInternalAPI( base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.FederationClient, ) api.KeyInternalAPI { - _, consumer, producer := jetstream.Prepare(&cfg.Matrix.JetStream) + js, consumer, _ := jetstream.Prepare(&cfg.Matrix.JetStream) db, err := storage.NewDatabase(&cfg.Database) if err != nil { logrus.WithError(err).Panicf("failed to connect to key server database") } keyChangeProducer := &producers.KeyChange{ - Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), - Producer: producer, - DB: db, + Topic: string(cfg.Matrix.JetStream.TopicFor(jetstream.OutputKeyChangeEvent)), + JetStream: js, + DB: db, } ap := &internal.KeyInternalAPI{ DB: db, diff --git a/keyserver/producers/keychange.go b/keyserver/producers/keychange.go index c5ddf2c1..fd143c6c 100644 --- a/keyserver/producers/keychange.go +++ b/keyserver/producers/keychange.go @@ -18,43 +18,47 @@ import ( "context" "encoding/json" - "github.com/Shopify/sarama" eduapi "github.com/matrix-org/dendrite/eduserver/api" "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/keyserver/storage" + "github.com/matrix-org/dendrite/setup/jetstream" + "github.com/nats-io/nats.go" "github.com/sirupsen/logrus" ) // KeyChange produces key change events for the sync API and federation sender to consume type KeyChange struct { - Topic string - Producer sarama.SyncProducer - DB storage.Database + Topic string + JetStream nats.JetStreamContext + DB storage.Database } // ProduceKeyChanges creates new change events for each key func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { userToDeviceCount := make(map[string]int) for _, key := range keys { - var m sarama.ProducerMessage - + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + key.DeviceChangeID = id value, err := json.Marshal(key) if err != nil { return err } - m.Topic = string(p.Topic) - m.Key = sarama.StringEncoder(key.UserID) - m.Value = sarama.ByteEncoder(value) + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, + } + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value - partition, offset, err := p.Producer.SendMessage(&m) - if err != nil { - return err - } - err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + _, err = p.JetStream.PublishMsg(m) if err != nil { return err } + userToDeviceCount[key.UserID]++ } for userID, count := range userToDeviceCount { @@ -67,7 +71,6 @@ func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) 
error { } func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) error { - var m sarama.ProducerMessage output := &api.DeviceMessage{ Type: api.TypeCrossSigningUpdate, OutputCrossSigningKeyUpdate: &eduapi.OutputCrossSigningKeyUpdate{ @@ -75,20 +78,25 @@ func (p *KeyChange) ProduceSigningKeyUpdate(key eduapi.CrossSigningKeyUpdate) er }, } + id, err := p.DB.StoreKeyChange(context.Background(), key.UserID) + if err != nil { + return err + } + output.DeviceChangeID = id + value, err := json.Marshal(output) if err != nil { return err } - m.Topic = string(p.Topic) - m.Key = sarama.StringEncoder(key.UserID) - m.Value = sarama.ByteEncoder(value) - - partition, offset, err := p.Producer.SendMessage(&m) - if err != nil { - return err + m := &nats.Msg{ + Subject: p.Topic, + Header: nats.Header{}, } - err = p.DB.StoreKeyChange(context.Background(), partition, offset, key.UserID) + m.Header.Set(jetstream.UserID, key.UserID) + m.Data = value + + _, err = p.JetStream.PublishMsg(m) if err != nil { return err } diff --git a/keyserver/storage/interface.go b/keyserver/storage/interface.go index 99842bc5..87feae47 100644 --- a/keyserver/storage/interface.go +++ b/keyserver/storage/interface.go @@ -66,14 +66,14 @@ type Database interface { // cannot be claimed or if none exist for this (user, device, algorithm), instead it is omitted from the returned slice. ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[string]map[string]string) ([]api.OneTimeKeys, error) - // StoreKeyChange stores key change metadata after the change has been sent to Kafka. `userID` is the the user who has changed - // their keys in some way. - StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + // StoreKeyChange stores key change metadata and returns the device change ID which represents the position in the /sync stream for this device change. + // `userID` is the user who has changed their keys in some way. + StoreKeyChange(ctx context.Context, userID string) (int64, error) // KeyChanges returns a list of user IDs who have modified their keys from the offset given (exclusive) to the offset given (inclusive). // A to offset of sarama.OffsetNewest means no upper limit. // Returns the offset of the latest key change. - KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) + KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. // If no domains are given, all user IDs with stale device lists are returned. diff --git a/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..e5bcf08d --- /dev/null +++ b/keyserver/storage/postgres/deltas/2022012016470000_key_changes.go @@ -0,0 +1,79 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func LoadRefactorKeyChanges(m *sqlutil.Migrations) { + m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func UpRefactorKeyChanges(tx *sql.Tx) error { + // start counting from the last max offset, else 0. We need to do a count(*) first to see if there + // even are entries in this table to know if we can query for log_offset. Without the count then + // the query to SELECT the max log offset fails on new Dendrite instances as log_offset doesn't + // exist on that table. Even though we discard the error, the txn is tainted and gets aborted :/ + var count int + _ = tx.QueryRow(`SELECT count(*) FROM keyserver_key_changes`).Scan(&count) + if count > 0 { + var maxOffset int64 + _ = tx.QueryRow(`SELECT coalesce(MAX(log_offset), 0) AS offset FROM keyserver_key_changes`).Scan(&maxOffset) + if _, err := tx.Exec(fmt.Sprintf(`CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq START %d`, maxOffset)); err != nil { + return fmt.Errorf("failed to CREATE SEQUENCE for key changes, starting at %d: %s", maxOffset, err) + } + } + + _, err := tx.Exec(` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + return nil +} + +func DownRefactorKeyChanges(tx *sql.Tx) error { + _, err := tx.Exec(` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP SEQUENCE IF EXISTS keyserver_key_changes_seq; + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + log_offset BIGINT NOT NULL, + user_id TEXT NOT NULL, + CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/postgres/key_changes_table.go b/keyserver/storage/postgres/key_changes_table.go index df4b47e7..20d227c2 100644 --- a/keyserver/storage/postgres/key_changes_table.go +++ b/keyserver/storage/postgres/key_changes_table.go @@ -26,27 +26,25 @@ import ( var keyChangesSchema = ` -- Stores key change information about users. Used to determine when to send updated device lists to clients. +CREATE SEQUENCE IF NOT EXISTS keyserver_key_changes_seq; CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - log_offset BIGINT NOT NULL, + change_id BIGINT PRIMARY KEY DEFAULT nextval('keyserver_key_changes_seq'), user_id TEXT NOT NULL, - CONSTRAINT keyserver_key_changes_unique UNIQUE (partition, log_offset) + CONSTRAINT keyserver_key_changes_unique_per_user UNIQUE (user_id) ); ` -// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. -// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will -// miss out on updates). 
TODO: Ideally we would detect when kafka logs are purged then purge this table too. +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (partition, log_offset, user_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique" + - " DO UPDATE SET user_id = $3" + "INSERT INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " ON CONFLICT ON CONSTRAINT keyserver_key_changes_unique_per_user" + + " DO UPDATE SET change_id = nextval('keyserver_key_changes_seq')" + + " RETURNING change_id" -// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just -// take the max offset value as the latest offset. const selectKeyChangesSQL = "" + - "SELECT user_id, MAX(log_offset) FROM keyserver_key_changes WHERE partition = $1 AND log_offset > $2 AND log_offset <= $3 GROUP BY user_id" + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" type keyChangesStatements struct { db *sql.DB @@ -59,31 +57,32 @@ func NewPostgresKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { db: db, } _, err := db.Exec(keyChangesSchema) - if err != nil { - return nil, err - } - if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { - return nil, err - } - if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { - return nil, err - } - return s, nil + return s, err } -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err +func (s *keyChangesStatements) Prepare() (err error) { + if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { + return err + } + if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { + return err + } + return nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return } func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, partition int32, fromOffset, toOffset int64, + ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { if toOffset == sarama.OffsetNewest { toOffset = math.MaxInt64 } latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { return nil, 0, err } diff --git a/keyserver/storage/postgres/storage.go b/keyserver/storage/postgres/storage.go index 52f3a7f6..b71cc1a7 100644 --- a/keyserver/storage/postgres/storage.go +++ b/keyserver/storage/postgres/storage.go @@ -16,6 +16,7 @@ package postgres import ( "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/keyserver/storage/postgres/deltas" "github.com/matrix-org/dendrite/keyserver/storage/shared" "github.com/matrix-org/dendrite/setup/config" ) @@ -51,6 +52,14 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + m := sqlutil.NewMigrations() + deltas.LoadRefactorKeyChanges(m) + if err = m.RunDeltas(db, 
dbProperties); err != nil { + return nil, err + } + if err = kc.Prepare(); err != nil { + return nil, err + } d := &shared.Database{ DB: db, Writer: sqlutil.NewDummyWriter(), diff --git a/keyserver/storage/shared/storage.go b/keyserver/storage/shared/storage.go index 5bd8be36..5914d28e 100644 --- a/keyserver/storage/shared/storage.go +++ b/keyserver/storage/shared/storage.go @@ -135,14 +135,16 @@ func (d *Database) ClaimKeys(ctx context.Context, userToDeviceToAlgorithm map[st return result, err } -func (d *Database) StoreKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - return d.Writer.Do(nil, nil, func(_ *sql.Tx) error { - return d.KeyChangesTable.InsertKeyChange(ctx, partition, offset, userID) +func (d *Database) StoreKeyChange(ctx context.Context, userID string) (id int64, err error) { + err = d.Writer.Do(nil, nil, func(_ *sql.Tx) error { + id, err = d.KeyChangesTable.InsertKeyChange(ctx, userID) + return err }) + return } -func (d *Database) KeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { - return d.KeyChangesTable.SelectKeyChanges(ctx, partition, fromOffset, toOffset) +func (d *Database) KeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error) { + return d.KeyChangesTable.SelectKeyChanges(ctx, fromOffset, toOffset) } // StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists. diff --git a/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go new file mode 100644 index 00000000..fbc548c3 --- /dev/null +++ b/keyserver/storage/sqlite3/deltas/2022012016470000_key_changes.go @@ -0,0 +1,76 @@ +// Copyright 2022 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deltas + +import ( + "database/sql" + "fmt" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/pressly/goose" +) + +func LoadFromGoose() { + goose.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func LoadRefactorKeyChanges(m *sqlutil.Migrations) { + m.AddMigration(UpRefactorKeyChanges, DownRefactorKeyChanges) +} + +func UpRefactorKeyChanges(tx *sql.Tx) error { + // start counting from the last max offset, else 0. 
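+	// NB: this is deliberately best-effort: the error from the query below is
+	// discarded, so on a fresh database (where there is no old log_offset data
+	// to carry over) userID stays empty, the seed insert at the end is skipped
+	// and the new table simply starts counting from scratch.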
+ var maxOffset int64 + var userID string + _ = tx.QueryRow(`SELECT user_id, MAX(log_offset) FROM keyserver_key_changes GROUP BY user_id`).Scan(&userID, &maxOffset) + + _, err := tx.Exec(` + -- make the new table + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + change_id INTEGER PRIMARY KEY AUTOINCREMENT, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (user_id) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute upgrade: %w", err) + } + // to start counting from maxOffset, insert a row with that value + if userID != "" { + _, err = tx.Exec(`INSERT INTO keyserver_key_changes(change_id, user_id) VALUES($1, $2)`, maxOffset, userID) + return err + } + return nil +} + +func DownRefactorKeyChanges(tx *sql.Tx) error { + _, err := tx.Exec(` + -- Drop all data and revert back, we can't keep the data as Kafka offsets determine the numbers + DROP TABLE IF EXISTS keyserver_key_changes; + CREATE TABLE IF NOT EXISTS keyserver_key_changes ( + partition BIGINT NOT NULL, + offset BIGINT NOT NULL, + -- The key owner + user_id TEXT NOT NULL, + UNIQUE (partition, offset) + ); + `) + if err != nil { + return fmt.Errorf("failed to execute downgrade: %w", err) + } + return nil +} diff --git a/keyserver/storage/sqlite3/key_changes_table.go b/keyserver/storage/sqlite3/key_changes_table.go index b4753ccc..d43c15ca 100644 --- a/keyserver/storage/sqlite3/key_changes_table.go +++ b/keyserver/storage/sqlite3/key_changes_table.go @@ -27,27 +27,22 @@ import ( var keyChangesSchema = ` -- Stores key change information about users. Used to determine when to send updated device lists to clients. CREATE TABLE IF NOT EXISTS keyserver_key_changes ( - partition BIGINT NOT NULL, - offset BIGINT NOT NULL, + change_id INTEGER PRIMARY KEY AUTOINCREMENT, -- The key owner user_id TEXT NOT NULL, - UNIQUE (partition, offset) + UNIQUE (user_id) ); ` -// Replace based on partition|offset - we should never insert duplicates unless the kafka logs are wiped. -// Rather than falling over, just overwrite (though this will mean clients with an existing sync token will -// miss out on updates). TODO: Ideally we would detect when kafka logs are purged then purge this table too. +// Replace based on user ID. We don't care how many times the user's keys have changed, only that they +// have changed, hence we can just keep bumping the change ID for this user. const upsertKeyChangeSQL = "" + - "INSERT INTO keyserver_key_changes (partition, offset, user_id)" + - " VALUES ($1, $2, $3)" + - " ON CONFLICT (partition, offset)" + - " DO UPDATE SET user_id = $3" + "INSERT OR REPLACE INTO keyserver_key_changes (user_id)" + + " VALUES ($1)" + + " RETURNING change_id" -// select the highest offset for each user in the range. The grouping by user gives distinct entries and then we just -// take the max offset value as the latest offset. 
const selectKeyChangesSQL = "" + - "SELECT user_id, MAX(offset) FROM keyserver_key_changes WHERE partition = $1 AND offset > $2 AND offset <= $3 GROUP BY user_id" + "SELECT user_id, change_id FROM keyserver_key_changes WHERE change_id > $1 AND change_id <= $2" type keyChangesStatements struct { db *sql.DB @@ -60,31 +55,32 @@ func NewSqliteKeyChangesTable(db *sql.DB) (tables.KeyChanges, error) { db: db, } _, err := db.Exec(keyChangesSchema) - if err != nil { - return nil, err - } - if s.upsertKeyChangeStmt, err = db.Prepare(upsertKeyChangeSQL); err != nil { - return nil, err - } - if s.selectKeyChangesStmt, err = db.Prepare(selectKeyChangesSQL); err != nil { - return nil, err - } - return s, nil + return s, err } -func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error { - _, err := s.upsertKeyChangeStmt.ExecContext(ctx, partition, offset, userID) - return err +func (s *keyChangesStatements) Prepare() (err error) { + if s.upsertKeyChangeStmt, err = s.db.Prepare(upsertKeyChangeSQL); err != nil { + return err + } + if s.selectKeyChangesStmt, err = s.db.Prepare(selectKeyChangesSQL); err != nil { + return err + } + return nil +} + +func (s *keyChangesStatements) InsertKeyChange(ctx context.Context, userID string) (changeID int64, err error) { + err = s.upsertKeyChangeStmt.QueryRowContext(ctx, userID).Scan(&changeID) + return } func (s *keyChangesStatements) SelectKeyChanges( - ctx context.Context, partition int32, fromOffset, toOffset int64, + ctx context.Context, fromOffset, toOffset int64, ) (userIDs []string, latestOffset int64, err error) { if toOffset == sarama.OffsetNewest { toOffset = math.MaxInt64 } latestOffset = fromOffset - rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) + rows, err := s.selectKeyChangesStmt.QueryContext(ctx, fromOffset, toOffset) if err != nil { return nil, 0, err } diff --git a/keyserver/storage/sqlite3/storage.go b/keyserver/storage/sqlite3/storage.go index ee1746cd..50ce00d0 100644 --- a/keyserver/storage/sqlite3/storage.go +++ b/keyserver/storage/sqlite3/storage.go @@ -17,6 +17,7 @@ package sqlite3 import ( "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/keyserver/storage/shared" + "github.com/matrix-org/dendrite/keyserver/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/setup/config" ) @@ -49,6 +50,15 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*shared.Database, error) if err != nil { return nil, err } + + m := sqlutil.NewMigrations() + deltas.LoadRefactorKeyChanges(m) + if err = m.RunDeltas(db, dbProperties); err != nil { + return nil, err + } + if err = kc.Prepare(); err != nil { + return nil, err + } d := &shared.Database{ DB: db, Writer: sqlutil.NewExclusiveWriter(), diff --git a/keyserver/storage/storage_test.go b/keyserver/storage/storage_test.go index 4e0a8af1..2f8cf809 100644 --- a/keyserver/storage/storage_test.go +++ b/keyserver/storage/storage_test.go @@ -44,15 +44,18 @@ func MustNotError(t *testing.T, err error) { func TestKeyChanges(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 1, sarama.OffsetNewest) + _, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, 
"@bob:localhost") + MustNotError(t, err) + deviceChangeIDC, err := db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDB, sarama.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 2 { - t.Fatalf("KeyChanges: got latest=%d want 2", latest) + if latest != deviceChangeIDC { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDC) } if !reflect.DeepEqual(userIDs, []string{"@charlie:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) @@ -62,15 +65,21 @@ func TestKeyChanges(t *testing.T) { func TestKeyChangesNoDupes(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@alice:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 0, sarama.OffsetNewest) + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + if deviceChangeIDA == deviceChangeIDB { + t.Fatalf("Expected change ID to be different even when inserting key change for the same user, got %d for both changes", deviceChangeIDA) + } + deviceChangeID, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, 0, sarama.OffsetNewest) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 2 { - t.Fatalf("KeyChanges: got latest=%d want 2", latest) + if latest != deviceChangeID { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeID) } if !reflect.DeepEqual(userIDs, []string{"@alice:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) @@ -80,15 +89,18 @@ func TestKeyChangesNoDupes(t *testing.T) { func TestKeyChangesUpperLimit(t *testing.T) { db, clean := MustCreateDatabase(t) defer clean() - MustNotError(t, db.StoreKeyChange(ctx, 0, 0, "@alice:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 1, "@bob:localhost")) - MustNotError(t, db.StoreKeyChange(ctx, 0, 2, "@charlie:localhost")) - userIDs, latest, err := db.KeyChanges(ctx, 0, 0, 1) + deviceChangeIDA, err := db.StoreKeyChange(ctx, "@alice:localhost") + MustNotError(t, err) + deviceChangeIDB, err := db.StoreKeyChange(ctx, "@bob:localhost") + MustNotError(t, err) + _, err = db.StoreKeyChange(ctx, "@charlie:localhost") + MustNotError(t, err) + userIDs, latest, err := db.KeyChanges(ctx, deviceChangeIDA, deviceChangeIDB) if err != nil { t.Fatalf("Failed to KeyChanges: %s", err) } - if latest != 1 { - t.Fatalf("KeyChanges: got latest=%d want 1", latest) + if latest != deviceChangeIDB { + t.Fatalf("KeyChanges: got latest=%d want %d", latest, deviceChangeIDB) } if !reflect.DeepEqual(userIDs, []string{"@bob:localhost"}) { t.Fatalf("KeyChanges: wrong user_ids: %v", userIDs) diff --git a/keyserver/storage/tables/interface.go b/keyserver/storage/tables/interface.go index 612eeb86..0d94c94c 100644 --- a/keyserver/storage/tables/interface.go +++ b/keyserver/storage/tables/interface.go @@ -44,10 +44,12 @@ type DeviceKeys interface { } type KeyChanges interface { - InsertKeyChange(ctx context.Context, partition int32, offset int64, userID string) error + InsertKeyChange(ctx context.Context, userID string) (int64, error) // SelectKeyChanges returns the set (de-duplicated) of users who have changed their keys between the two 
offsets.
 	// Results are exclusive of fromOffset and inclusive of toOffset. A toOffset of sarama.OffsetNewest means no upper offset.
-	SelectKeyChanges(ctx context.Context, partition int32, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+	SelectKeyChanges(ctx context.Context, fromOffset, toOffset int64) (userIDs []string, latestOffset int64, err error)
+
+	Prepare() error
 }
 
 type StaleDeviceLists interface {
diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go
index d63e4832..97685cc0 100644
--- a/syncapi/consumers/keychange.go
+++ b/syncapi/consumers/keychange.go
@@ -17,7 +17,6 @@ package consumers
 import (
 	"context"
 	"encoding/json"
-	"sync"
 
 	"github.com/Shopify/sarama"
 	"github.com/getsentry/sentry-go"
@@ -34,16 +33,14 @@ import (
 
 // OutputKeyChangeEventConsumer consumes events that originated in the key server.
 type OutputKeyChangeEventConsumer struct {
-	ctx                 context.Context
-	keyChangeConsumer   *internal.ContinualConsumer
-	db                  storage.Database
-	notifier            *notifier.Notifier
-	stream              types.StreamProvider
-	serverName          gomatrixserverlib.ServerName // our server name
-	rsAPI               roomserverAPI.RoomserverInternalAPI
-	keyAPI              api.KeyInternalAPI
-	partitionToOffset   map[int32]int64
-	partitionToOffsetMu sync.Mutex
+	ctx               context.Context
+	keyChangeConsumer *internal.ContinualConsumer
+	db                storage.Database
+	notifier          *notifier.Notifier
+	stream            types.StreamProvider
+	serverName        gomatrixserverlib.ServerName // our server name
+	rsAPI             roomserverAPI.RoomserverInternalAPI
+	keyAPI            api.KeyInternalAPI
 }
 
 // NewOutputKeyChangeEventConsumer creates a new OutputKeyChangeEventConsumer.
@@ -69,16 +66,14 @@
 	}
 
 	s := &OutputKeyChangeEventConsumer{
-		ctx:                 process.Context(),
-		keyChangeConsumer:   &consumer,
-		db:                  store,
-		serverName:          serverName,
-		keyAPI:              keyAPI,
-		rsAPI:               rsAPI,
-		partitionToOffset:   make(map[int32]int64),
-		partitionToOffsetMu: sync.Mutex{},
-		notifier:            notifier,
-		stream:              stream,
+		ctx:               process.Context(),
+		keyChangeConsumer: &consumer,
+		db:                store,
+		serverName:        serverName,
+		keyAPI:            keyAPI,
+		rsAPI:             rsAPI,
+		notifier:          notifier,
+		stream:            stream,
 	}
 
 	consumer.ProcessMessage = s.onMessage
@@ -88,24 +83,10 @@
 
 // Start consuming from the key server
 func (s *OutputKeyChangeEventConsumer) Start() error {
-	offsets, err := s.keyChangeConsumer.StartOffsets()
-	s.partitionToOffsetMu.Lock()
-	for _, o := range offsets {
-		s.partitionToOffset[o.Partition] = o.Offset
-	}
-	s.partitionToOffsetMu.Unlock()
-	return err
-}
-
-func (s *OutputKeyChangeEventConsumer) updateOffset(msg *sarama.ConsumerMessage) {
-	s.partitionToOffsetMu.Lock()
-	defer s.partitionToOffsetMu.Unlock()
-	s.partitionToOffset[msg.Partition] = msg.Offset
+	return s.keyChangeConsumer.Start()
 }
 
 func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
-	defer s.updateOffset(msg)
-
 	var m api.DeviceMessage
 	if err := json.Unmarshal(msg.Value, &m); err != nil {
 		logrus.WithError(err).Errorf("failed to read device message from key change topic")
@@ -118,15 +99,15 @@ func (s *OutputKeyChangeEventConsumer) onMessage(msg *sarama.ConsumerMessage) er
 	}
 	switch m.Type {
 	case api.TypeCrossSigningUpdate:
-		return s.onCrossSigningMessage(m, msg.Offset)
+		return s.onCrossSigningMessage(m, m.DeviceChangeID)
 	case api.TypeDeviceKeyUpdate:
 		fallthrough
 	default:
-		return s.onDeviceKeyMessage(m, msg.Offset)
+		return s.onDeviceKeyMessage(m, m.DeviceChangeID)
 	}
 }
 
-func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, offset int64) error {
+func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, deviceChangeID int64) error {
 	if m.DeviceKeys == nil {
 		return nil
 	}
@@ -143,7 +124,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
 	}
 	// make sure we get our own key updates too!
 	queryRes.UserIDsToCount[output.UserID] = 1
-	posUpdate := types.StreamPosition(offset)
+	posUpdate := types.StreamPosition(deviceChangeID)
 	s.stream.Advance(posUpdate)
 
 	for userID := range queryRes.UserIDsToCount {
@@ -153,7 +134,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
 	return nil
 }
 
-func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, offset int64) error {
+func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage, deviceChangeID int64) error {
 	output := m.CrossSigningKeyUpdate
 	// work out who we need to notify about the new key
 	var queryRes roomserverAPI.QuerySharedUsersResponse
@@ -167,7 +148,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
 	}
 	// make sure we get our own key updates too!
 	queryRes.UserIDsToCount[output.UserID] = 1
-	posUpdate := types.StreamPosition(offset)
+	posUpdate := types.StreamPosition(deviceChangeID)
 	s.stream.Advance(posUpdate)
 
 	for userID := range queryRes.UserIDsToCount {
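
The hunks above collapse the per-partition Kafka offset map into a single stream position taken from `m.DeviceChangeID`. A minimal sketch of that single-position bookkeeping follows; the names here are assumptions for illustration, not Dendrite's actual `types.StreamProvider`:

// Sketch only: a monotonic device-list stream position, advanced from a
// globally ordered DeviceChangeID instead of per-partition Kafka offsets.
package main

import (
	"fmt"
	"sync"
)

type StreamPosition int64

// deviceListStream tracks only the latest position. Advance never moves
// backwards, so out-of-order redelivery after a restart cannot rewind it.
type deviceListStream struct {
	mu     sync.Mutex
	latest StreamPosition
}

func (s *deviceListStream) Advance(pos StreamPosition) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if pos > s.latest {
		s.latest = pos
	}
}

func (s *deviceListStream) LatestPosition() StreamPosition {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.latest
}

func main() {
	s := &deviceListStream{}
	// Device change IDs come from one ordered source (e.g. a database
	// sequence), so no partition-to-offset map is needed to merge them.
	for _, id := range []int64{1, 3, 2} {
		s.Advance(StreamPosition(id))
	}
	fmt.Println(s.LatestPosition()) // prints 3
}

Because every change ID is drawn from a single ordered source, the consumer needs no per-partition state at all, which is what lets the struct fields and `updateOffset` above be deleted.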

From 0bf5104bbbfb34cf869e74bcc17a039db176e935 Mon Sep 17 00:00:00 2001
From: kegsay
Date: Fri, 21 Jan 2022 14:23:37 +0000
Subject: [PATCH 3/6] Fix #2027 by gracefully handling stub rooms (#2100)

The server ACL code on startup will grab all known rooms from the
rooms_table and then call `GetStateEvent` with each found room ID to
find the server ACL event. This can fail for stub rooms, which will be
present in the rooms table. Previously this would result in an error
being returned and the server failing to start (!). Now we just return
no event for stub rooms.
---
 roomserver/storage/shared/storage.go | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go
index f49536f4..d4c5ebb5 100644
--- a/roomserver/storage/shared/storage.go
+++ b/roomserver/storage/shared/storage.go
@@ -842,9 +842,13 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
 	if err != nil {
 		return nil, err
 	}
-	if roomInfo == nil || roomInfo.IsStub {
+	if roomInfo == nil {
 		return nil, fmt.Errorf("room %s doesn't exist", roomID)
 	}
+	// e.g invited rooms
+	if roomInfo.IsStub {
+		return nil, nil
+	}
 	eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, nil, evType)
 	if err == sql.ErrNoRows {
 		// No rooms have an event of this type, otherwise we'd have an event type NID

From 3c92b55bec57519dfd5b5a0f06c32b3ae85f5710 Mon Sep 17 00:00:00 2001
From: FORCHA
Date: Fri, 21 Jan 2022 15:37:59 +0100
Subject: [PATCH 4/6] Update monolith-sample.conf (#2087)

* Update monolith-sample.conf

Replaced the undefined `monolith` value with the server_name
(my.hostname.com) value, in reference to this issue:
https://github.com/matrix-org/dendrite/issues/2078

* Update monolith-sample.conf

Changed the IP to the location of the monolith server.

Co-authored-by: kegsay
---
 docs/nginx/monolith-sample.conf | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/docs/nginx/monolith-sample.conf b/docs/nginx/monolith-sample.conf
index 0344aa96..360eb925 100644
--- a/docs/nginx/monolith-sample.conf
+++ b/docs/nginx/monolith-sample.conf
@@ -1,3 +1,7 @@
+#change IP to location of monolith server
+upstream monolith{
+    server 127.0.0.1:8008;
+}
 server {
     listen 443 ssl; # IPv4
     listen [::]:443 ssl; # IPv6
@@ -23,6 +27,6 @@ server {
     }
 
     location /_matrix {
-        proxy_pass http://monolith:8008;
+        proxy_pass http://monolith;
     }
 }

From cd1391fc62d9801b2246e38e3491f52fba585fff Mon Sep 17 00:00:00 2001
From: kegsay
Date: Fri, 21 Jan 2022 14:46:47 +0000
Subject: [PATCH 5/6] Document log levels (#2101)

---
 dendrite-config.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/dendrite-config.yaml b/dendrite-config.yaml
index 7593988f..38b146d7 100644
--- a/dendrite-config.yaml
+++ b/dendrite-config.yaml
@@ -368,6 +368,7 @@ logging:
   - type: std
     level: info
   - type: file
+    # The logging level, must be one of debug, info, warn, error, fatal, panic.
     level: info
     params:
       path: ./logs

From 96bf4aa8383e13624b823806b07a5878c8b3ade3 Mon Sep 17 00:00:00 2001
From: Neil Alexander
Date: Fri, 21 Jan 2022 14:59:47 +0000
Subject: [PATCH 6/6] Add `Forward extremities remain so even after the next
 events are populated as outliers` to `sytest-whitelist`

---
 sytest-whitelist | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sytest-whitelist b/sytest-whitelist
index 558eb29a..f11cd96b 100644
--- a/sytest-whitelist
+++ b/sytest-whitelist
@@ -588,3 +588,4 @@ User can invite remote user to room with version 9
 Remote user can backfill in a room with version 9
 Can reject invites over federation for rooms with version 9
 Can receive redactions from regular users over federation in room version 9
+Forward extremities remain so even after the next events are populated as outliers
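
As a closing note on the stub-room fix in patch 3/6: callers of `GetStateEvent` must now distinguish "room unknown" (an error) from "stub room, no state stored yet" (a nil event with a nil error). A hedged sketch of that caller-side contract, with invented names rather than Dendrite's real storage API:

// Sketch only: how a startup scan tolerates (nil, nil) from a state lookup.
package main

import "fmt"

// roomInfo mirrors the idea of a rooms_table row: stub rooms (e.g. rooms we
// have only been invited to) have no state stored locally yet.
type roomInfo struct {
	IsStub bool
}

type event struct {
	Content string
}

// getStateEvent mimics the patched behaviour: unknown room -> error,
// stub room -> (nil, nil), otherwise return the requested state event.
func getStateEvent(rooms map[string]roomInfo, roomID string) (*event, error) {
	info, ok := rooms[roomID]
	if !ok {
		return nil, fmt.Errorf("room %s doesn't exist", roomID)
	}
	if info.IsStub {
		return nil, nil // no local state, but not an error
	}
	return &event{Content: `{"allow": ["*"]}`}, nil
}

func main() {
	rooms := map[string]roomInfo{
		"!full:example.org": {IsStub: false},
		"!stub:example.org": {IsStub: true},
	}
	for id := range rooms {
		ev, err := getStateEvent(rooms, id)
		if err != nil {
			// Before the fix, stub rooms returned an error and landed here,
			// aborting server startup.
			fmt.Println("fatal:", err)
			continue
		}
		if ev == nil {
			fmt.Println("no server ACL event for", id, "- skipping")
			continue
		}
		fmt.Println("loaded server ACL for", id, ev.Content)
	}
}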