Remove partitioned stream positions

This was used by the device list stream position. The device list position
now corresponds to the `Offset`, and the partition is always 0, in prep
for removing reliance on Kafka topics for device list changes.
This commit is contained in:
Kegan Dougal 2022-01-19 17:39:47 +00:00
parent 8d3a2d87e6
commit 43f56a45a5
14 changed files with 63 additions and 227 deletions

View file

@ -473,7 +473,7 @@ func Setup(
}), }),
).Methods(http.MethodGet, http.MethodOptions) ).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/passwordaa", r0mux.Handle("/account/password",
httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse { httputil.MakeAuthAPI("password", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
if r := rateLimits.Limit(req); r != nil { if r := rateLimits.Limit(req); r != nil {
return *r return *r

View file

@ -224,8 +224,6 @@ type QueryKeysResponse struct {
} }
type QueryKeyChangesRequest struct { type QueryKeyChangesRequest struct {
// The partition which had key events sent to
Partition int32
// The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning // The offset of the last received key event, or sarama.OffsetOldest if this is from the beginning
Offset int64 Offset int64
// The inclusive offset where to track key changes up to. Messages with this offset are included in the response. // The inclusive offset where to track key changes up to. Messages with this offset are included in the response.
@ -236,8 +234,6 @@ type QueryKeyChangesRequest struct {
type QueryKeyChangesResponse struct { type QueryKeyChangesResponse struct {
// The set of users who have had their keys change. // The set of users who have had their keys change.
UserIDs []string UserIDs []string
// The partition being served - useful if the partition is unknown at request time
Partition int32
// The latest offset represented in this response. // The latest offset represented in this response.
Offset int64 Offset int64
// Set if there was a problem handling the request. // Set if there was a problem handling the request.

View file

@ -59,17 +59,14 @@ func (a *KeyInternalAPI) InputDeviceListUpdate(
} }
func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) { func (a *KeyInternalAPI) QueryKeyChanges(ctx context.Context, req *api.QueryKeyChangesRequest, res *api.QueryKeyChangesResponse) {
if req.Partition < 0 { partition := 0
req.Partition = a.Producer.DefaultPartition() userIDs, latest, err := a.DB.KeyChanges(ctx, int32(partition), req.Offset, req.ToOffset)
}
userIDs, latest, err := a.DB.KeyChanges(ctx, req.Partition, req.Offset, req.ToOffset)
if err != nil { if err != nil {
res.Error = &api.KeyError{ res.Error = &api.KeyError{
Err: err.Error(), Err: err.Error(),
} }
} }
res.Offset = latest res.Offset = latest
res.Partition = req.Partition
res.UserIDs = userIDs res.UserIDs = userIDs
} }

View file

@ -32,15 +32,6 @@ type KeyChange struct {
DB storage.Database DB storage.Database
} }
// DefaultPartition returns the default partition this process is sending key changes to.
// NB: A keyserver MUST send key changes to only 1 partition or else query operations will
// become inconsistent. Partitions can be sharded (e.g by hash of user ID of key change) but
// then all keyservers must be queried to calculate the entire set of key changes between
// two sync tokens.
func (p *KeyChange) DefaultPartition() int32 {
return 0
}
// ProduceKeyChanges creates new change events for each key // ProduceKeyChanges creates new change events for each key
func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error { func (p *KeyChange) ProduceKeyChanges(keys []api.DeviceMessage) error {
userToDeviceCount := make(map[string]int) userToDeviceCount := make(map[string]int)

View file

@ -38,7 +38,7 @@ type OutputKeyChangeEventConsumer struct {
keyChangeConsumer *internal.ContinualConsumer keyChangeConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
stream types.PartitionedStreamProvider stream types.StreamProvider
serverName gomatrixserverlib.ServerName // our server name serverName gomatrixserverlib.ServerName // our server name
rsAPI roomserverAPI.RoomserverInternalAPI rsAPI roomserverAPI.RoomserverInternalAPI
keyAPI api.KeyInternalAPI keyAPI api.KeyInternalAPI
@ -57,7 +57,7 @@ func NewOutputKeyChangeEventConsumer(
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
store storage.Database, store storage.Database,
notifier *notifier.Notifier, notifier *notifier.Notifier,
stream types.PartitionedStreamProvider, stream types.StreamProvider,
) *OutputKeyChangeEventConsumer { ) *OutputKeyChangeEventConsumer {
consumer := internal.ContinualConsumer{ consumer := internal.ContinualConsumer{
@ -143,10 +143,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
} }
// make sure we get our own key updates too! // make sure we get our own key updates too!
queryRes.UserIDsToCount[output.UserID] = 1 queryRes.UserIDsToCount[output.UserID] = 1
posUpdate := types.LogPosition{ posUpdate := types.StreamPosition(offset)
Offset: offset,
Partition: partition,
}
s.stream.Advance(posUpdate) s.stream.Advance(posUpdate)
for userID := range queryRes.UserIDsToCount { for userID := range queryRes.UserIDsToCount {
@ -170,10 +167,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
} }
// make sure we get our own key updates too! // make sure we get our own key updates too!
queryRes.UserIDsToCount[output.UserID] = 1 queryRes.UserIDsToCount[output.UserID] = 1
posUpdate := types.LogPosition{ posUpdate := types.StreamPosition(offset)
Offset: offset,
Partition: partition,
}
s.stream.Advance(posUpdate) s.stream.Advance(posUpdate)
for userID := range queryRes.UserIDsToCount { for userID := range queryRes.UserIDsToCount {

View file

@ -47,8 +47,8 @@ func DeviceOTKCounts(ctx context.Context, keyAPI keyapi.KeyInternalAPI, userID,
// be already filled in with join/leave information. // be already filled in with join/leave information.
func DeviceListCatchup( func DeviceListCatchup(
ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI, ctx context.Context, keyAPI keyapi.KeyInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
userID string, res *types.Response, from, to types.LogPosition, userID string, res *types.Response, from, to types.StreamPosition,
) (newPos types.LogPosition, hasNew bool, err error) { ) (newPos types.StreamPosition, hasNew bool, err error) {
// Track users who we didn't track before but now do by virtue of sharing a room with them, or not. // Track users who we didn't track before but now do by virtue of sharing a room with them, or not.
newlyJoinedRooms := joinedRooms(res, userID) newlyJoinedRooms := joinedRooms(res, userID)
@ -64,27 +64,18 @@ func DeviceListCatchup(
} }
// now also track users who we already share rooms with but who have updated their devices between the two tokens // now also track users who we already share rooms with but who have updated their devices between the two tokens
offset := sarama.OffsetOldest
var partition int32 toOffset := sarama.OffsetNewest
var offset int64 if to > 0 && to > from {
partition = -1 toOffset = int64(to)
offset = sarama.OffsetOldest
// Extract partition/offset from sync token
// TODO: In a world where keyserver is sharded there will be multiple partitions and hence multiple QueryKeyChanges to make.
if !from.IsEmpty() {
partition = from.Partition
offset = from.Offset
} }
var toOffset int64 if from > 0 {
toOffset = sarama.OffsetNewest offset = int64(from)
if toLog := to; toLog.Partition == partition && toLog.Offset > 0 {
toOffset = toLog.Offset
} }
var queryRes keyapi.QueryKeyChangesResponse var queryRes keyapi.QueryKeyChangesResponse
keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{ keyAPI.QueryKeyChanges(ctx, &keyapi.QueryKeyChangesRequest{
Partition: partition, Offset: offset,
Offset: offset, ToOffset: toOffset,
ToOffset: toOffset,
}, &queryRes) }, &queryRes)
if queryRes.Error != nil { if queryRes.Error != nil {
// don't fail the catchup because we may have got useful information by tracking membership // don't fail the catchup because we may have got useful information by tracking membership
@ -95,8 +86,8 @@ func DeviceListCatchup(
var sharedUsersMap map[string]int var sharedUsersMap map[string]int
sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs) sharedUsersMap, queryRes.UserIDs = filterSharedUsers(ctx, rsAPI, userID, queryRes.UserIDs)
util.GetLogger(ctx).Debugf( util.GetLogger(ctx).Debugf(
"QueryKeyChanges request p=%d,off=%d,to=%d response p=%d off=%d uids=%v", "QueryKeyChanges request off=%d,to=%d response off=%d uids=%v",
partition, offset, toOffset, queryRes.Partition, queryRes.Offset, queryRes.UserIDs, offset, toOffset, queryRes.Offset, queryRes.UserIDs,
) )
userSet := make(map[string]bool) userSet := make(map[string]bool)
for _, userID := range res.DeviceLists.Changed { for _, userID := range res.DeviceLists.Changed {
@ -125,13 +116,8 @@ func DeviceListCatchup(
res.DeviceLists.Left = append(res.DeviceLists.Left, userID) res.DeviceLists.Left = append(res.DeviceLists.Left, userID)
} }
} }
// set the new token
to = types.LogPosition{
Partition: queryRes.Partition,
Offset: queryRes.Offset,
}
return to, hasNew, nil return types.StreamPosition(queryRes.Offset), hasNew, nil
} }
// TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response. // TrackChangedUsers calculates the values of device_lists.changed|left in the /sync response.

View file

@ -6,7 +6,6 @@ import (
"sort" "sort"
"testing" "testing"
"github.com/Shopify/sarama"
keyapi "github.com/matrix-org/dendrite/keyserver/api" keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
@ -16,11 +15,7 @@ import (
var ( var (
syncingUser = "@alice:localhost" syncingUser = "@alice:localhost"
emptyToken = types.LogPosition{} emptyToken = types.StreamPosition(0)
newestToken = types.LogPosition{
Offset: sarama.OffsetNewest,
Partition: 0,
}
) )
type mockKeyAPI struct{} type mockKeyAPI struct{}
@ -186,7 +181,7 @@ func TestKeyChangeCatchupOnJoinShareNewUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -209,7 +204,7 @@ func TestKeyChangeCatchupOnLeaveShareLeftUser(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -232,7 +227,7 @@ func TestKeyChangeCatchupOnJoinShareNoNewUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -254,7 +249,7 @@ func TestKeyChangeCatchupOnLeaveShareNoUsers(t *testing.T) {
"!another:room": {syncingUser, existingUser}, "!another:room": {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -313,7 +308,7 @@ func TestKeyChangeCatchupNoNewJoinsButMessages(t *testing.T) {
roomID: {syncingUser, existingUser}, roomID: {syncingUser, existingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)
} }
@ -341,7 +336,7 @@ func TestKeyChangeCatchupChangeAndLeft(t *testing.T) {
"!another:room": {syncingUser}, "!another:room": {syncingUser},
}, },
} }
_, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken) _, hasNew, err := DeviceListCatchup(context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken)
if err != nil { if err != nil {
t.Fatalf("Catchup returned an error: %s", err) t.Fatalf("Catchup returned an error: %s", err)
} }
@ -427,7 +422,7 @@ func TestKeyChangeCatchupChangeAndLeftSameRoom(t *testing.T) {
}, },
} }
_, hasNew, err := DeviceListCatchup( _, hasNew, err := DeviceListCatchup(
context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, newestToken, context.Background(), &mockKeyAPI{}, rsAPI, syncingUser, syncResponse, emptyToken, emptyToken,
) )
if err != nil { if err != nil {
t.Fatalf("DeviceListCatchup returned an error: %s", err) t.Fatalf("DeviceListCatchup returned an error: %s", err)

View file

@ -10,7 +10,7 @@ import (
) )
type DeviceListStreamProvider struct { type DeviceListStreamProvider struct {
PartitionedStreamProvider StreamProvider
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
keyAPI keyapi.KeyInternalAPI keyAPI keyapi.KeyInternalAPI
} }
@ -18,15 +18,15 @@ type DeviceListStreamProvider struct {
func (p *DeviceListStreamProvider) CompleteSync( func (p *DeviceListStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
) types.LogPosition { ) types.StreamPosition {
return p.LatestPosition(ctx) return p.LatestPosition(ctx)
} }
func (p *DeviceListStreamProvider) IncrementalSync( func (p *DeviceListStreamProvider) IncrementalSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
from, to types.LogPosition, from, to types.StreamPosition,
) types.LogPosition { ) types.StreamPosition {
var err error var err error
to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to) to, _, err = internal.DeviceListCatchup(context.Background(), p.keyAPI, p.rsAPI, req.Device.UserID, req.Response, from, to)
if err != nil { if err != nil {

View file

@ -18,7 +18,7 @@ type Streams struct {
InviteStreamProvider types.StreamProvider InviteStreamProvider types.StreamProvider
SendToDeviceStreamProvider types.StreamProvider SendToDeviceStreamProvider types.StreamProvider
AccountDataStreamProvider types.StreamProvider AccountDataStreamProvider types.StreamProvider
DeviceListStreamProvider types.PartitionedStreamProvider DeviceListStreamProvider types.StreamProvider
} }
func NewSyncStreamProviders( func NewSyncStreamProviders(
@ -48,9 +48,9 @@ func NewSyncStreamProviders(
userAPI: userAPI, userAPI: userAPI,
}, },
DeviceListStreamProvider: &DeviceListStreamProvider{ DeviceListStreamProvider: &DeviceListStreamProvider{
PartitionedStreamProvider: PartitionedStreamProvider{DB: d}, StreamProvider: StreamProvider{DB: d},
rsAPI: rsAPI, rsAPI: rsAPI,
keyAPI: keyAPI, keyAPI: keyAPI,
}, },
} }

View file

@ -1,38 +0,0 @@
package streams
import (
"context"
"sync"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
)
type PartitionedStreamProvider struct {
DB storage.Database
latest types.LogPosition
latestMutex sync.RWMutex
}
func (p *PartitionedStreamProvider) Setup() {
}
func (p *PartitionedStreamProvider) Advance(
latest types.LogPosition,
) {
p.latestMutex.Lock()
defer p.latestMutex.Unlock()
if latest.IsAfter(&p.latest) {
p.latest = latest
}
}
func (p *PartitionedStreamProvider) LatestPosition(
ctx context.Context,
) types.LogPosition {
p.latestMutex.RLock()
defer p.latestMutex.RUnlock()
return p.latest
}

View file

@ -140,6 +140,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// Extract values from request // Extract values from request
syncReq, err := newSyncRequest(req, *device, rp.db) syncReq, err := newSyncRequest(req, *device, rp.db)
if err != nil { if err != nil {
if err == types.ErrMalformedSyncToken {
return util.JSONResponse{
Code: http.StatusUnauthorized,
JSON: jsonerror.InvalidArgumentValue(err.Error()),
}
}
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.Unknown(err.Error()), JSON: jsonerror.Unknown(err.Error()),

View file

@ -42,11 +42,3 @@ type StreamProvider interface {
// LatestPosition returns the latest stream position for this stream. // LatestPosition returns the latest stream position for this stream.
LatestPosition(ctx context.Context) StreamPosition LatestPosition(ctx context.Context) StreamPosition
} }
type PartitionedStreamProvider interface {
Setup()
Advance(latest LogPosition)
CompleteSync(ctx context.Context, req *SyncRequest) LogPosition
IncrementalSync(ctx context.Context, req *SyncRequest, from, to LogPosition) LogPosition
LatestPosition(ctx context.Context) LogPosition
}

View file

@ -16,6 +16,7 @@ package types
import ( import (
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"strconv" "strconv"
"strings" "strings"
@ -26,13 +27,10 @@ import (
) )
var ( var (
// ErrInvalidSyncTokenType is returned when an attempt at creating a // This error is returned when parsing sync tokens if the token is invalid. Callers can use this
// new instance of SyncToken with an invalid type (i.e. neither "s" // error to detect whether to 400 or 401 the client. It is recommended to 401 them to force a
// nor "t"). // logout.
ErrInvalidSyncTokenType = fmt.Errorf("sync token has an unknown prefix (should be either s or t)") ErrMalformedSyncToken = errors.New("malformed sync token")
// ErrInvalidSyncTokenLen is returned when the pagination token is an
// invalid length
ErrInvalidSyncTokenLen = fmt.Errorf("sync token has an invalid length")
) )
type StateDelta struct { type StateDelta struct {
@ -47,27 +45,6 @@ type StateDelta struct {
// StreamPosition represents the offset in the sync stream a client is at. // StreamPosition represents the offset in the sync stream a client is at.
type StreamPosition int64 type StreamPosition int64
// LogPosition represents the offset in a Kafka log a client is at.
type LogPosition struct {
Partition int32
Offset int64
}
func (p *LogPosition) IsEmpty() bool {
return p.Offset == 0
}
// IsAfter returns true if this position is after `lp`.
func (p *LogPosition) IsAfter(lp *LogPosition) bool {
if lp == nil {
return false
}
if p.Partition != lp.Partition {
return false
}
return p.Offset > lp.Offset
}
// StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event. // StreamEvent is the same as gomatrixserverlib.Event but also has the PDU stream position for this event.
type StreamEvent struct { type StreamEvent struct {
*gomatrixserverlib.HeaderedEvent *gomatrixserverlib.HeaderedEvent
@ -124,7 +101,7 @@ type StreamingToken struct {
SendToDevicePosition StreamPosition SendToDevicePosition StreamPosition
InvitePosition StreamPosition InvitePosition StreamPosition
AccountDataPosition StreamPosition AccountDataPosition StreamPosition
DeviceListPosition LogPosition DeviceListPosition StreamPosition
} }
// This will be used as a fallback by json.Marshal. // This will be used as a fallback by json.Marshal.
@ -140,14 +117,11 @@ func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
func (t StreamingToken) String() string { func (t StreamingToken) String() string {
posStr := fmt.Sprintf( posStr := fmt.Sprintf(
"s%d_%d_%d_%d_%d_%d", "s%d_%d_%d_%d_%d_%d_%d",
t.PDUPosition, t.TypingPosition, t.PDUPosition, t.TypingPosition,
t.ReceiptPosition, t.SendToDevicePosition, t.ReceiptPosition, t.SendToDevicePosition,
t.InvitePosition, t.AccountDataPosition, t.InvitePosition, t.AccountDataPosition, t.DeviceListPosition,
) )
if dl := t.DeviceListPosition; !dl.IsEmpty() {
posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset)
}
return posStr return posStr
} }
@ -166,14 +140,14 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true return true
case t.AccountDataPosition > other.AccountDataPosition: case t.AccountDataPosition > other.AccountDataPosition:
return true return true
case t.DeviceListPosition.IsAfter(&other.DeviceListPosition): case t.DeviceListPosition > other.DeviceListPosition:
return true return true
} }
return false return false
} }
func (t *StreamingToken) IsEmpty() bool { func (t *StreamingToken) IsEmpty() bool {
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition == 0 && t.DeviceListPosition.IsEmpty() return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition+t.AccountDataPosition+t.DeviceListPosition == 0
} }
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
@ -208,7 +182,7 @@ func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.AccountDataPosition > t.AccountDataPosition { if other.AccountDataPosition > t.AccountDataPosition {
t.AccountDataPosition = other.AccountDataPosition t.AccountDataPosition = other.AccountDataPosition
} }
if other.DeviceListPosition.IsAfter(&t.DeviceListPosition) { if other.DeviceListPosition > t.DeviceListPosition {
t.DeviceListPosition = other.DeviceListPosition t.DeviceListPosition = other.DeviceListPosition
} }
} }
@ -292,16 +266,15 @@ func NewTopologyTokenFromString(tok string) (token TopologyToken, err error) {
func NewStreamTokenFromString(tok string) (token StreamingToken, err error) { func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
if len(tok) < 1 { if len(tok) < 1 {
err = fmt.Errorf("empty stream token") err = ErrMalformedSyncToken
return return
} }
if tok[0] != SyncTokenTypeStream[0] { if tok[0] != SyncTokenTypeStream[0] {
err = fmt.Errorf("stream token must start with 's'") err = ErrMalformedSyncToken
return return
} }
categories := strings.Split(tok[1:], ".") parts := strings.Split(tok[1:], "_")
parts := strings.Split(categories[0], "_") var positions [7]StreamPosition
var positions [6]StreamPosition
for i, p := range parts { for i, p := range parts {
if i > len(positions) { if i > len(positions) {
break break
@ -309,6 +282,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
var pos int var pos int
pos, err = strconv.Atoi(p) pos, err = strconv.Atoi(p)
if err != nil { if err != nil {
err = ErrMalformedSyncToken
return return
} }
positions[i] = StreamPosition(pos) positions[i] = StreamPosition(pos)
@ -320,31 +294,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
SendToDevicePosition: positions[3], SendToDevicePosition: positions[3],
InvitePosition: positions[4], InvitePosition: positions[4],
AccountDataPosition: positions[5], AccountDataPosition: positions[5],
} DeviceListPosition: positions[6],
// dl-0-1234
// $log_name-$partition-$offset
for _, logStr := range categories[1:] {
segments := strings.Split(logStr, "-")
if len(segments) != 3 {
err = fmt.Errorf("invalid log position %q", logStr)
return
}
switch segments[0] {
case "dl":
// Device list syncing
var partition, offset int
if partition, err = strconv.Atoi(segments[1]); err != nil {
return
}
if offset, err = strconv.Atoi(segments[2]); err != nil {
return
}
token.DeviceListPosition.Partition = int32(partition)
token.DeviceListPosition.Offset = int64(offset)
default:
err = fmt.Errorf("unrecognised token type %q", segments[0])
return
}
} }
return token, nil return token, nil
} }

View file

@ -2,50 +2,17 @@ package types
import ( import (
"encoding/json" "encoding/json"
"reflect"
"testing" "testing"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{
"s4_0_0_0_0_0": {
PDUPosition: 4,
},
"s4_0_0_0_0_0.dl-0-123": {
PDUPosition: 4,
DeviceListPosition: LogPosition{
Partition: 0,
Offset: 123,
},
},
}
for tok, want := range tests {
got, err := NewStreamTokenFromString(tok)
if err != nil {
if want == nil {
continue // error expected
}
t.Errorf("%s errored: %s", tok, err)
continue
}
if !reflect.DeepEqual(got, *want) {
t.Errorf("%s mismatch: got %v want %v", tok, got, want)
}
gotStr := got.String()
if gotStr != tok {
t.Errorf("%s reserialisation mismatch: got %s want %s", tok, gotStr, tok)
}
}
}
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{ shouldPass := map[string]string{
"s4_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, LogPosition{}}.String(), "s4_0_0_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, 0, 0}.String(),
"s3_1_0_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, 0, LogPosition{1, 2}}.String(), "s3_1_0_0_0_0_2": StreamingToken{3, 1, 0, 0, 0, 0, 2}.String(),
"s3_1_2_3_5_0": StreamingToken{3, 1, 2, 3, 5, 0, LogPosition{}}.String(), "s3_1_2_3_5_0_0": StreamingToken{3, 1, 2, 3, 5, 0, 0}.String(),
"t3_1": TopologyToken{3, 1}.String(), "t3_1": TopologyToken{3, 1}.String(),
} }
for a, b := range shouldPass { for a, b := range shouldPass {