More sane next batch handling, typing notification tweaks, give invites their own stream position, device list fix (#1641)

* Update sync responses

* Fix positions, add ApplyUpdates

* Fix MarshalText as non-pointer, PrevBatch is optional

* Increment by number of read receipts

* Merge branch 'master' into neilalexander/devicelist

* Tweak typing

* Include keyserver position tweak

* Fix typing next position in all cases

* Tweaks

* Fix typo

* Tweaks, restore StreamingToken.MarshalText which somehow went missing?

* Rely on positions from notifier rather than manually advancing them

* Revert "Rely on positions from notifier rather than manually advancing them"

This reverts commit 53112a62cc3bfd9989acab518e69eeb27938117a.

* Give invites their own position, fix other things

* Fix test

* Fix invites maybe

* Un-whitelist tests that look to be genuinely wrong

* Use real receipt positions

* Ensure send-to-device uses real positions too
This commit is contained in:
Neil Alexander 2020-12-18 11:11:21 +00:00 committed by GitHub
parent a518e2971a
commit 50963b724b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 228 additions and 152 deletions

View file

@ -82,6 +82,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -83,6 +83,7 @@ func (s *keyChangesStatements) SelectKeyChanges(
if toOffset == sarama.OffsetNewest { if toOffset == sarama.OffsetNewest {
toOffset = math.MaxInt64 toOffset = math.MaxInt64
} }
latestOffset = fromOffset
rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset) rows, err := s.selectKeyChangesStmt.QueryContext(ctx, partition, fromOffset, toOffset)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err

View file

@ -88,7 +88,7 @@ func (s *OutputReceiptEventConsumer) onMessage(msg *sarama.ConsumerMessage) erro
return err return err
} }
// update stream position // update stream position
s.notifier.OnNewReceipt(types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
return nil return nil
} }

View file

@ -94,10 +94,8 @@ func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage)
"event_type": output.Type, "event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server") }).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice() streamPos, err := s.db.StoreNewSendForDeviceMessage(
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent,
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
) )
if err != nil { if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message") log.WithError(err).Errorf("failed to store send-to-device message")

View file

@ -64,12 +64,7 @@ func NewOutputTypingEventConsumer(
// Start consuming from EDU api // Start consuming from EDU api
func (s *OutputTypingEventConsumer) Start() error { func (s *OutputTypingEventConsumer) Start() error {
s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) { s.db.SetTypingTimeoutCallback(func(userID, roomID string, latestSyncPosition int64) {
s.notifier.OnNewEvent( s.notifier.OnNewTyping(roomID, types.StreamingToken{TypingPosition: types.StreamPosition(latestSyncPosition)})
nil, roomID, nil,
types.StreamingToken{
TypingPosition: types.StreamPosition(latestSyncPosition),
},
)
}) })
return s.typingConsumer.Start() return s.typingConsumer.Start()
@ -97,6 +92,6 @@ func (s *OutputTypingEventConsumer) onMessage(msg *sarama.ConsumerMessage) error
typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID) typingPos = s.db.RemoveTypingUser(typingEvent.UserID, typingEvent.RoomID)
} }
s.notifier.OnNewEvent(nil, output.Event.RoomID, nil, types.StreamingToken{TypingPosition: typingPos}) s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
return nil return nil
} }

View file

@ -259,6 +259,12 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *gom
func (s *OutputRoomEventConsumer) onNewInviteEvent( func (s *OutputRoomEventConsumer) onNewInviteEvent(
ctx context.Context, msg api.OutputNewInviteEvent, ctx context.Context, msg api.OutputNewInviteEvent,
) error { ) error {
if msg.Event.StateKey() == nil {
log.WithFields(log.Fields{
"event": string(msg.Event.JSON()),
}).Panicf("roomserver output log: invite has no state key")
return nil
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event) pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
@ -269,14 +275,14 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
}).Panicf("roomserver output log: write invite failure") }).Panicf("roomserver output log: write invite failure")
return nil return nil
} }
s.notifier.OnNewEvent(msg.Event, "", nil, types.StreamingToken{PDUPosition: pduPos}) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, *msg.Event.StateKey())
return nil return nil
} }
func (s *OutputRoomEventConsumer) onRetireInviteEvent( func (s *OutputRoomEventConsumer) onRetireInviteEvent(
ctx context.Context, msg api.OutputRetireInviteEvent, ctx context.Context, msg api.OutputRetireInviteEvent,
) error { ) error {
sp, err := s.db.RetireInviteEvent(ctx, msg.EventID) pduPos, err := s.db.RetireInviteEvent(ctx, msg.EventID)
if err != nil { if err != nil {
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
@ -287,7 +293,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
} }
// Notify any active sync requests that the invite has been retired. // Notify any active sync requests that the invite has been retired.
// Invites share the same stream counter as PDUs // Invites share the same stream counter as PDUs
s.notifier.OnNewEvent(nil, "", []string{msg.TargetUserID}, types.StreamingToken{PDUPosition: sp}) s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
return nil return nil
} }

View file

@ -132,7 +132,7 @@ func DeviceListCatchup(
Partition: queryRes.Partition, Partition: queryRes.Partition,
Offset: queryRes.Offset, Offset: queryRes.Offset,
} }
res.NextBatch = to.String() res.NextBatch.ApplyUpdates(to)
return hasNew, nil return hasNew, nil
} }

View file

@ -130,9 +130,9 @@ type Database interface {
// can be deleted altogether by CleanSendToDeviceUpdates // can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since" // The token supplied should be the current requested sync token, e.g. from the "since"
// parameter. // parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error) SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the // CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows // result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after // SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after

View file

@ -55,7 +55,7 @@ const upsertReceipt = "" +
" RETURNING id" " RETURNING id"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE room_id = ANY($1) AND id > $2" " WHERE room_id = ANY($1) AND id > $2"
@ -95,22 +95,27 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
return return
} }
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
lastPos := types.StreamPosition(0)
rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos) rows, err := r.selectRoomReceipts.QueryContext(ctx, pq.Array(roomIDs), streamPos)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
} }
func (s *receiptStatements) SelectMaxReceiptID( func (s *receiptStatements) SelectMaxReceiptID(

View file

@ -49,6 +49,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
const insertSendToDeviceMessageSQL = ` const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content) INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
RETURNING id
` `
const countSendToDeviceMessagesSQL = ` const countSendToDeviceMessagesSQL = `
@ -107,8 +108,8 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).QueryRowContext(ctx, userID, deviceID, content).Scan(&pos)
return return
} }
@ -124,7 +125,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil { if err != nil {
return return
@ -152,9 +153,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
} }
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
return events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -492,6 +492,7 @@ func (d *Database) syncPositionTx(
PDUPosition: types.StreamPosition(maxEventID), PDUPosition: types.StreamPosition(maxEventID),
TypingPosition: types.StreamPosition(d.EDUCache.GetLatestSyncPosition()), TypingPosition: types.StreamPosition(d.EDUCache.GetLatestSyncPosition()),
ReceiptPosition: types.StreamPosition(maxReceiptID), ReceiptPosition: types.StreamPosition(maxReceiptID),
InvitePosition: types.StreamPosition(maxInviteID),
} }
return return
} }
@ -543,11 +544,6 @@ func (d *Database) addPDUDeltaToResponse(
} }
} }
// TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
succeeded = true succeeded = true
return joinedRoomIDs, nil return joinedRoomIDs, nil
} }
@ -583,6 +579,7 @@ func (d *Database) addTypingDeltaToResponse(
res.Rooms.Join[roomID] = jr res.Rooms.Join[roomID] = jr
} }
} }
res.NextBatch.TypingPosition = types.StreamPosition(d.EDUCache.GetLatestSyncPosition())
return nil return nil
} }
@ -593,7 +590,7 @@ func (d *Database) addReceiptDeltaToResponse(
joinedRoomIDs []string, joinedRoomIDs []string,
res *types.Response, res *types.Response,
) error { ) error {
receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.ReceiptPosition) lastPos, receipts, err := d.Receipts.SelectRoomReceiptsAfter(context.TODO(), joinedRoomIDs, since.ReceiptPosition)
if err != nil { if err != nil {
return fmt.Errorf("unable to select receipts for rooms: %w", err) return fmt.Errorf("unable to select receipts for rooms: %w", err)
} }
@ -638,6 +635,7 @@ func (d *Database) addReceiptDeltaToResponse(
res.Rooms.Join[roomID] = jr res.Rooms.Join[roomID] = jr
} }
res.NextBatch.ReceiptPosition = lastPos
return nil return nil
} }
@ -691,8 +689,7 @@ func (d *Database) IncrementalSync(
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
) (*types.Response, error) { ) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos) res.NextBatch = fromPos.WithUpdates(toPos)
res.NextBatch = nextBatchPos.String()
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
@ -725,6 +722,14 @@ func (d *Database) IncrementalSync(
return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err) return nil, fmt.Errorf("d.addEDUDeltaToResponse: %w", err)
} }
ir := types.Range{
From: fromPos.InvitePosition,
To: toPos.InvitePosition,
}
if err = d.addInvitesToResponse(ctx, nil, device.UserID, ir, res); err != nil {
return nil, fmt.Errorf("d.addInvitesToResponse: %w", err)
}
return res, nil return res, nil
} }
@ -783,8 +788,12 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
From: 0, From: 0,
To: toPos.PDUPosition, To: toPos.PDUPosition,
} }
ir := types.Range{
From: 0,
To: toPos.InvitePosition,
}
res.NextBatch = toPos.String() res.NextBatch.ApplyUpdates(toPos)
// Extract room state and recent events for all rooms the user is joined to. // Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -824,7 +833,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
} }
} }
if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, ir, res); err != nil {
return return
} }
@ -884,19 +893,18 @@ func (d *Database) getJoinResponseForCompleteSync(
// Retrieve the backward topology position, i.e. the position of the // Retrieve the backward topology position, i.e. the position of the
// oldest event in the room's topology. // oldest event in the room's topology.
var prevBatchStr string var prevBatch *types.TopologyToken
if len(recentStreamEvents) > 0 { if len(recentStreamEvents) > 0 {
var backwardTopologyPos, backwardStreamPos types.StreamPosition var backwardTopologyPos, backwardStreamPos types.StreamPosition
backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID()) backwardTopologyPos, backwardStreamPos, err = d.Topology.SelectPositionInTopology(ctx, txn, recentStreamEvents[0].EventID())
if err != nil { if err != nil {
return return
} }
prevBatch := types.TopologyToken{ prevBatch = &types.TopologyToken{
Depth: backwardTopologyPos, Depth: backwardTopologyPos,
PDUPosition: backwardStreamPos, PDUPosition: backwardStreamPos,
} }
prevBatch.Decrement() prevBatch.Decrement()
prevBatchStr = prevBatch.String()
} }
// We don't include a device here as we don't need to send down // We don't include a device here as we don't need to send down
@ -905,7 +913,7 @@ func (d *Database) getJoinResponseForCompleteSync(
recentEvents := d.StreamEventsToEvents(&device, recentStreamEvents) recentEvents := d.StreamEventsToEvents(&device, recentStreamEvents)
stateEvents = removeDuplicates(stateEvents, recentEvents) stateEvents = removeDuplicates(stateEvents, recentEvents)
jr = types.NewJoinResponse() jr = types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatchStr jr.Timeline.PrevBatch = prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
@ -1033,7 +1041,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Join: case gomatrixserverlib.Join:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1041,7 +1049,7 @@ func (d *Database) addRoomDeltaToResponse(
case gomatrixserverlib.Peek: case gomatrixserverlib.Peek:
jr := types.NewJoinResponse() jr := types.NewJoinResponse()
jr.Timeline.PrevBatch = prevBatch.String() jr.Timeline.PrevBatch = &prevBatch
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
jr.Timeline.Limited = limited jr.Timeline.Limited = limited
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1052,7 +1060,7 @@ func (d *Database) addRoomDeltaToResponse(
// TODO: recentEvents may contain events that this user is not allowed to see because they are // TODO: recentEvents may contain events that this user is not allowed to see because they are
// no longer in the room. // no longer in the room.
lr := types.NewLeaveResponse() lr := types.NewLeaveResponse()
lr.Timeline.PrevBatch = prevBatch.String() lr.Timeline.PrevBatch = &prevBatch
lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync) lr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true lr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync) lr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
@ -1373,39 +1381,40 @@ func (d *Database) SendToDeviceUpdatesWaiting(
} }
func (d *Database) StoreNewSendForDeviceMessage( func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) { ) (newPos types.StreamPosition, err error) {
j, err := json.Marshal(event) j, err := json.Marshal(event)
if err != nil { if err != nil {
return streamPos, err return 0, err
} }
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee // Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place. // that we don't lock the table for writes in more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.SendToDevice.InsertSendToDeviceMessage( newPos, err = d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, string(j), ctx, txn, userID, deviceID, string(j),
) )
return err
}) })
if err != nil { if err != nil {
return streamPos, err return 0, err
} }
return streamPos, nil return 0, nil
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
token types.StreamingToken, token types.StreamingToken,
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
// If there's nothing to do then stop here. // If there's nothing to do then stop here.
if len(events) == 0 { if len(events) == 0 {
return nil, nil, nil, nil return 0, nil, nil, nil, nil
} }
// Work out whether we need to update any of the database entries. // Work out whether we need to update any of the database entries.
@ -1432,7 +1441,7 @@ func (d *Database) SendToDeviceUpdatesForSync(
} }
} }
return toReturn, toUpdate, toDelete, nil return lastPos, toReturn, toUpdate, toDelete, nil
} }
func (d *Database) CleanSendToDeviceUpdates( func (d *Database) CleanSendToDeviceUpdates(
@ -1519,5 +1528,6 @@ func (d *Database) StoreReceipt(ctx context.Context, roomId, receiptType, userId
} }
func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) { func (d *Database) GetRoomReceipts(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) {
return d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos) _, receipts, err := d.Receipts.SelectRoomReceiptsAfter(ctx, roomIDs, streamPos)
return receipts, err
} }

View file

@ -51,7 +51,7 @@ const upsertReceipt = "" +
" DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9" " DO UPDATE SET id = $7, event_id = $8, receipt_ts = $9"
const selectRoomReceipts = "" + const selectRoomReceipts = "" +
"SELECT room_id, receipt_type, user_id, event_id, receipt_ts" + "SELECT id, room_id, receipt_type, user_id, event_id, receipt_ts" +
" FROM syncapi_receipts" + " FROM syncapi_receipts" +
" WHERE id > $1 and room_id in ($2)" " WHERE id > $1 and room_id in ($2)"
@ -99,9 +99,9 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
} }
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp // SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]api.OutputReceiptEvent, error) { func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []api.OutputReceiptEvent, error) {
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1) selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
lastPos := types.StreamPosition(0)
params := make([]interface{}, len(roomIDs)+1) params := make([]interface{}, len(roomIDs)+1)
params[0] = streamPos params[0] = streamPos
for k, v := range roomIDs { for k, v := range roomIDs {
@ -109,19 +109,23 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
} }
rows, err := r.db.QueryContext(ctx, selectSQL, params...) rows, err := r.db.QueryContext(ctx, selectSQL, params...)
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to query room receipts: %w", err) return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectRoomReceiptsAfter: rows.close() failed")
var res []api.OutputReceiptEvent var res []api.OutputReceiptEvent
for rows.Next() { for rows.Next() {
r := api.OutputReceiptEvent{} r := api.OutputReceiptEvent{}
err = rows.Scan(&r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp) var id types.StreamPosition
err = rows.Scan(&id, &r.RoomID, &r.Type, &r.UserID, &r.EventID, &r.Timestamp)
if err != nil { if err != nil {
return res, fmt.Errorf("unable to scan row to api.Receipts: %w", err) return 0, res, fmt.Errorf("unable to scan row to api.Receipts: %w", err)
} }
res = append(res, r) res = append(res, r)
if id > lastPos {
lastPos = id
}
} }
return res, rows.Err() return lastPos, res, rows.Err()
} }
func (s *receiptStatements) SelectMaxReceiptID( func (s *receiptStatements) SelectMaxReceiptID(

View file

@ -100,8 +100,14 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
func (s *sendToDeviceStatements) InsertSendToDeviceMessage( func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string, ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) { ) (pos types.StreamPosition, err error) {
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content) var result sql.Result
result, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
if p, err := result.LastInsertId(); err != nil {
return 0, err
} else {
pos = types.StreamPosition(p)
}
return return
} }
@ -117,7 +123,7 @@ func (s *sendToDeviceStatements) CountSendToDeviceMessages(
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil { if err != nil {
return return
@ -145,9 +151,12 @@ func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
} }
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
return events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(

View file

@ -228,7 +228,7 @@ func TestSyncResponse(t *testing.T) {
ReceiptPosition: latest.ReceiptPosition, ReceiptPosition: latest.ReceiptPosition,
SendToDevicePosition: latest.SendToDevicePosition, SendToDevicePosition: latest.SendToDevicePosition,
} }
if res.NextBatch != next.String() { if res.NextBatch.String() != next.String() {
st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String()) st.Errorf("NextBatch got %s want %s", res.NextBatch, next.String())
} }
roomRes, ok := res.Rooms.Join[testRoomID] roomRes, ok := res.Rooms.Join[testRoomID]
@ -266,7 +266,7 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
// returns the last event "Message 10" // returns the last event "Message 10"
assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:])) assertEventsEqual(t, "IncrementalSync Timeline", false, roomRes.Timeline.Events, reversed(events[len(events)-1:]))
prev := roomRes.Timeline.PrevBatch prev := roomRes.Timeline.PrevBatch.String()
if prev == "" { if prev == "" {
t.Fatalf("IncrementalSync expected prev_batch token") t.Fatalf("IncrementalSync expected prev_batch token")
} }
@ -539,7 +539,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point there should be no messages. We haven't sent anything // At this point there should be no messages. We haven't sent anything
// yet. // yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{}) _, events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -552,7 +552,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
} }
// Try sending a message. // Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{ streamPos, err := db.StoreNewSendForDeviceMessage(ctx, "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob", Sender: "bob",
Type: "m.type", Type: "m.type",
Content: json.RawMessage("{}"), Content: json.RawMessage("{}"),
@ -564,7 +564,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should get exactly one message. We're sending the sync position // At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated // that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at. // in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos}) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -579,7 +579,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have one message because we haven't progressed the // At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying // sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position. // with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos}) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -593,7 +593,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should now have no updates, because we've progressed the sync // At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again. // position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1}) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 1})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -607,7 +607,7 @@ func TestSendToDeviceBehaviour(t *testing.T) {
// At this point we should still have no updates, because no new updates have been // At this point we should still have no updates, because no new updates have been
// sent. // sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2}) _, events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.StreamingToken{SendToDevicePosition: streamPos + 2})
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -666,12 +666,8 @@ func TestInviteBehaviour(t *testing.T) {
assertInvitedToRooms(t, res, []string{inviteRoom2}) assertInvitedToRooms(t, res, []string{inviteRoom2})
// a sync after we have received both invites should result in a leave for the retired room // a sync after we have received both invites should result in a leave for the retired room
beforeRetireTok, err := types.NewStreamTokenFromString(beforeRetireRes.NextBatch)
if err != nil {
t.Fatalf("NewStreamTokenFromString cannot parse next batch '%s' : %s", beforeRetireRes.NextBatch, err)
}
res = types.NewResponse() res = types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireTok, latest, 0, false) res, err = db.IncrementalSync(ctx, res, testUserDeviceA, beforeRetireRes.NextBatch, latest, 0, false)
if err != nil { if err != nil {
t.Fatalf("IncrementalSync failed: %s", err) t.Fatalf("IncrementalSync failed: %s", err)
} }

View file

@ -146,8 +146,8 @@ type BackwardsExtremities interface {
// sync parameter isn't later then we will keep including the updates in the // sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error) CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
@ -160,6 +160,6 @@ type Filter interface {
type Receipts interface { type Receipts interface {
UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error) UpsertReceipt(ctx context.Context, txn *sql.Tx, roomId, receiptType, userId, eventId string, timestamp gomatrixserverlib.Timestamp) (pos types.StreamPosition, err error)
SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) ([]eduAPI.OutputReceiptEvent, error) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxReceiptID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }

View file

@ -77,9 +77,8 @@ func (n *Notifier) OnNewEvent(
// This needs to be done PRIOR to waking up users as they will read this value. // This needs to be done PRIOR to waking up users as they will read this value.
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.currPos.ApplyUpdates(posUpdate)
n.removeEmptyUserStreams() n.removeEmptyUserStreams()
if ev != nil { if ev != nil {
@ -113,11 +112,11 @@ func (n *Notifier) OnNewEvent(
} }
} }
n.wakeupUsers(usersToNotify, peekingDevicesToNotify, latestPos) n.wakeupUsers(usersToNotify, peekingDevicesToNotify, n.currPos)
} else if roomID != "" { } else if roomID != "" {
n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), latestPos) n.wakeupUsers(n.joinedUsers(roomID), n.PeekingDevices(roomID), n.currPos)
} else if len(userIDs) > 0 { } else if len(userIDs) > 0 {
n.wakeupUsers(userIDs, nil, latestPos) n.wakeupUsers(userIDs, nil, n.currPos)
} else { } else {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"posUpdate": posUpdate.String, "posUpdate": posUpdate.String,
@ -155,20 +154,33 @@ func (n *Notifier) OnNewSendToDevice(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos) n.currPos.ApplyUpdates(posUpdate)
n.wakeupUserDevice(userID, deviceIDs, n.currPos)
} }
// OnNewReceipt updates the current position // OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt( func (n *Notifier) OnNewTyping(
roomID string,
posUpdate types.StreamingToken, posUpdate types.StreamingToken,
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
}
// OnNewReceipt updates the current position
func (n *Notifier) OnNewReceipt(
roomID string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers(n.joinedUsers(roomID), nil, n.currPos)
} }
func (n *Notifier) OnNewKeyChange( func (n *Notifier) OnNewKeyChange(
@ -176,9 +188,19 @@ func (n *Notifier) OnNewKeyChange(
) { ) {
n.streamLock.Lock() n.streamLock.Lock()
defer n.streamLock.Unlock() defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, latestPos) n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
}
func (n *Notifier) OnNewInvite(
posUpdate types.StreamingToken, wakeUserID string,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
n.currPos.ApplyUpdates(posUpdate)
n.wakeupUsers([]string{wakeUserID}, nil, n.currPos)
} }
// GetListener returns a UserStreamListener that can be used to wait for // GetListener returns a UserStreamListener that can be used to wait for

View file

@ -335,7 +335,7 @@ func waitForEvents(n *Notifier, req syncRequest) (types.StreamingToken, error) {
return types.StreamingToken{}, fmt.Errorf( return types.StreamingToken{}, fmt.Errorf(
"waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since, "waitForEvents timed out waiting for %s (pos=%v)", req.device.UserID, req.since,
) )
case <-listener.GetNotifyChannel(*req.since): case <-listener.GetNotifyChannel(req.since):
p := listener.GetSyncPosition() p := listener.GetSyncPosition()
return p, nil return p, nil
} }
@ -365,7 +365,7 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syn
ID: deviceID, ID: deviceID,
}, },
timeout: 1 * time.Minute, timeout: 1 * time.Minute,
since: &since, since: since,
wantFullState: false, wantFullState: false,
limit: DefaultTimelineLimit, limit: DefaultTimelineLimit,
log: util.GetLogger(context.TODO()), log: util.GetLogger(context.TODO()),

View file

@ -46,7 +46,7 @@ type syncRequest struct {
device userapi.Device device userapi.Device
limit int limit int
timeout time.Duration timeout time.Duration
since *types.StreamingToken // nil means that no since token was supplied since types.StreamingToken // nil means that no since token was supplied
wantFullState bool wantFullState bool
log *log.Entry log *log.Entry
} }
@ -55,17 +55,13 @@ func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Dat
timeout := getTimeout(req.URL.Query().Get("timeout")) timeout := getTimeout(req.URL.Query().Get("timeout"))
fullState := req.URL.Query().Get("full_state") fullState := req.URL.Query().Get("full_state")
wantFullState := fullState != "" && fullState != "false" wantFullState := fullState != "" && fullState != "false"
var since *types.StreamingToken since, sinceStr := types.StreamingToken{}, req.URL.Query().Get("since")
sinceStr := req.URL.Query().Get("since")
if sinceStr != "" { if sinceStr != "" {
tok, err := types.NewStreamTokenFromString(sinceStr) var err error
since, err = types.NewStreamTokenFromString(sinceStr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
since = &tok
}
if since == nil {
since = &types.StreamingToken{}
} }
timelineLimit := DefaultTimelineLimit timelineLimit := DefaultTimelineLimit
// TODO: read from stored filters too // TODO: read from stored filters too

View file

@ -185,13 +185,12 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
// respond with, so we skip the return an go back to waiting for content to // respond with, so we skip the return an go back to waiting for content to
// be sent down or the request timing out. // be sent down or the request timing out.
var hasTimedOut bool var hasTimedOut bool
sincePos := *syncReq.since sincePos := syncReq.since
for { for {
select { select {
// Wait for notifier to wake us up // Wait for notifier to wake us up
case <-userStreamListener.GetNotifyChannel(sincePos): case <-userStreamListener.GetNotifyChannel(sincePos):
currPos = userStreamListener.GetSyncPosition() currPos = userStreamListener.GetSyncPosition()
sincePos = currPos
// Or for timeout to expire // Or for timeout to expire
case <-timer.C: case <-timer.C:
// We just need to ensure we get out of the select after reaching the // We just need to ensure we get out of the select after reaching the
@ -279,7 +278,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
res := types.NewResponse() res := types.NewResponse()
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, *req.since) lastPos, events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, req.since)
if err != nil { if err != nil {
return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err) return nil, fmt.Errorf("rp.db.SendToDeviceUpdatesForSync: %w", err)
} }
@ -291,7 +290,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
return res, fmt.Errorf("rp.db.CompleteSync: %w", err) return res, fmt.Errorf("rp.db.CompleteSync: %w", err)
} }
} else { } else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState) res, err = rp.db.IncrementalSync(req.ctx, res, req.device, req.since, latestPos, req.limit, req.wantFullState)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.db.IncrementalSync: %w", err) return res, fmt.Errorf("rp.db.IncrementalSync: %w", err)
} }
@ -302,7 +301,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
if err != nil { if err != nil {
return res, fmt.Errorf("rp.appendAccountData: %w", err) return res, fmt.Errorf("rp.appendAccountData: %w", err)
} }
res, err = rp.appendDeviceLists(res, req.device.UserID, *req.since, latestPos) res, err = rp.appendDeviceLists(res, req.device.UserID, req.since, latestPos)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.appendDeviceLists: %w", err) return res, fmt.Errorf("rp.appendDeviceLists: %w", err)
} }
@ -316,7 +315,7 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
// Then add the updates into the sync response. // Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 { if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database. // Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, *req.since) err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.since)
if err != nil { if err != nil {
return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err) return res, fmt.Errorf("rp.db.CleanSendToDeviceUpdates: %w", err)
} }
@ -326,15 +325,9 @@ func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.Strea
for _, event := range events { for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent) res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
} }
// Get the next_batch from the sync response and increase the
// EDU counter.
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.SendToDevicePosition++
res.NextBatch = pos.String()
}
} }
res.NextBatch.SendToDevicePosition = lastPos
return res, err return res, err
} }
@ -464,7 +457,7 @@ func (rp *RequestPool) appendAccountData(
// or timeout=0, or full_state=true, in any of the cases the request should // or timeout=0, or full_state=true, in any of the cases the request should
// return immediately. // return immediately.
func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool { func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState { if syncReq.since.IsEmpty() || syncReq.timeout == 0 || syncReq.wantFullState {
return true return true
} }
waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID) waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)

View file

@ -113,14 +113,27 @@ type StreamingToken struct {
TypingPosition StreamPosition TypingPosition StreamPosition
ReceiptPosition StreamPosition ReceiptPosition StreamPosition
SendToDevicePosition StreamPosition SendToDevicePosition StreamPosition
InvitePosition StreamPosition
DeviceListPosition LogPosition DeviceListPosition LogPosition
} }
// This will be used as a fallback by json.Marshal.
func (s StreamingToken) MarshalText() ([]byte, error) {
return []byte(s.String()), nil
}
// This will be used as a fallback by json.Unmarshal.
func (s *StreamingToken) UnmarshalText(text []byte) (err error) {
*s, err = NewStreamTokenFromString(string(text))
return err
}
func (t StreamingToken) String() string { func (t StreamingToken) String() string {
posStr := fmt.Sprintf( posStr := fmt.Sprintf(
"s%d_%d_%d_%d", "s%d_%d_%d_%d_%d",
t.PDUPosition, t.TypingPosition, t.PDUPosition, t.TypingPosition,
t.ReceiptPosition, t.SendToDevicePosition, t.ReceiptPosition, t.SendToDevicePosition,
t.InvitePosition,
) )
if dl := t.DeviceListPosition; !dl.IsEmpty() { if dl := t.DeviceListPosition; !dl.IsEmpty() {
posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset) posStr += fmt.Sprintf(".dl-%d-%d", dl.Partition, dl.Offset)
@ -139,6 +152,8 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
return true return true
case t.SendToDevicePosition > other.SendToDevicePosition: case t.SendToDevicePosition > other.SendToDevicePosition:
return true return true
case t.InvitePosition > other.InvitePosition:
return true
case t.DeviceListPosition.IsAfter(&other.DeviceListPosition): case t.DeviceListPosition.IsAfter(&other.DeviceListPosition):
return true return true
} }
@ -146,35 +161,59 @@ func (t *StreamingToken) IsAfter(other StreamingToken) bool {
} }
func (t *StreamingToken) IsEmpty() bool { func (t *StreamingToken) IsEmpty() bool {
return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition == 0 && t.DeviceListPosition.IsEmpty() return t == nil || t.PDUPosition+t.TypingPosition+t.ReceiptPosition+t.SendToDevicePosition+t.InvitePosition == 0 && t.DeviceListPosition.IsEmpty()
} }
// WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken. // WithUpdates returns a copy of the StreamingToken with updates applied from another StreamingToken.
// If the latter StreamingToken contains a field that is not 0, it is considered an update, // If the latter StreamingToken contains a field that is not 0, it is considered an update,
// and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called. // and its value will replace the corresponding value in the StreamingToken on which WithUpdates is called.
// If the other token has a log, they will replace any existing log on this token. // If the other token has a log, they will replace any existing log on this token.
func (t *StreamingToken) WithUpdates(other StreamingToken) (ret StreamingToken) { func (t *StreamingToken) WithUpdates(other StreamingToken) StreamingToken {
ret = *t ret := *t
switch { ret.ApplyUpdates(other)
case other.PDUPosition > 0:
ret.PDUPosition = other.PDUPosition
case other.TypingPosition > 0:
ret.TypingPosition = other.TypingPosition
case other.ReceiptPosition > 0:
ret.ReceiptPosition = other.ReceiptPosition
case other.SendToDevicePosition > 0:
ret.SendToDevicePosition = other.SendToDevicePosition
case other.DeviceListPosition.Offset > 0:
ret.DeviceListPosition = other.DeviceListPosition
}
return ret return ret
} }
// ApplyUpdates applies any changes from the supplied StreamingToken. If the supplied
// streaming token contains any positions that are not 0, they are considered updates
// and will overwrite the value in the token.
func (t *StreamingToken) ApplyUpdates(other StreamingToken) {
if other.PDUPosition > 0 {
t.PDUPosition = other.PDUPosition
}
if other.TypingPosition > 0 {
t.TypingPosition = other.TypingPosition
}
if other.ReceiptPosition > 0 {
t.ReceiptPosition = other.ReceiptPosition
}
if other.SendToDevicePosition > 0 {
t.SendToDevicePosition = other.SendToDevicePosition
}
if other.InvitePosition > 0 {
t.InvitePosition = other.InvitePosition
}
if other.DeviceListPosition.Offset > 0 {
t.DeviceListPosition = other.DeviceListPosition
}
}
type TopologyToken struct { type TopologyToken struct {
Depth StreamPosition Depth StreamPosition
PDUPosition StreamPosition PDUPosition StreamPosition
} }
// This will be used as a fallback by json.Marshal.
func (t TopologyToken) MarshalText() ([]byte, error) {
return []byte(t.String()), nil
}
// This will be used as a fallback by json.Unmarshal.
func (t *TopologyToken) UnmarshalText(text []byte) (err error) {
*t, err = NewTopologyTokenFromString(string(text))
return err
}
func (t *TopologyToken) StreamToken() StreamingToken { func (t *TopologyToken) StreamToken() StreamingToken {
return StreamingToken{ return StreamingToken{
PDUPosition: t.PDUPosition, PDUPosition: t.PDUPosition,
@ -247,7 +286,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
} }
categories := strings.Split(tok[1:], ".") categories := strings.Split(tok[1:], ".")
parts := strings.Split(categories[0], "_") parts := strings.Split(categories[0], "_")
var positions [4]StreamPosition var positions [5]StreamPosition
for i, p := range parts { for i, p := range parts {
if i > len(positions) { if i > len(positions) {
break break
@ -264,6 +303,7 @@ func NewStreamTokenFromString(tok string) (token StreamingToken, err error) {
TypingPosition: positions[1], TypingPosition: positions[1],
ReceiptPosition: positions[2], ReceiptPosition: positions[2],
SendToDevicePosition: positions[3], SendToDevicePosition: positions[3],
InvitePosition: positions[4],
} }
// dl-0-1234 // dl-0-1234
// $log_name-$partition-$offset // $log_name-$partition-$offset
@ -302,7 +342,7 @@ type PrevEventRef struct {
// Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync // Response represents a /sync API response. See https://matrix.org/docs/spec/client_server/r0.2.0.html#get-matrix-client-r0-sync
type Response struct { type Response struct {
NextBatch string `json:"next_batch"` NextBatch StreamingToken `json:"next_batch"`
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
} `json:"account_data,omitempty"` } `json:"account_data,omitempty"`
@ -366,7 +406,7 @@ type JoinResponse struct {
Timeline struct { Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"` Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"` PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"` } `json:"timeline"`
Ephemeral struct { Ephemeral struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
@ -424,7 +464,7 @@ type LeaveResponse struct {
Timeline struct { Timeline struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events"`
Limited bool `json:"limited"` Limited bool `json:"limited"`
PrevBatch string `json:"prev_batch"` PrevBatch *TopologyToken `json:"prev_batch,omitempty"`
} `json:"timeline"` } `json:"timeline"`
} }

View file

@ -10,10 +10,10 @@ import (
func TestNewSyncTokenWithLogs(t *testing.T) { func TestNewSyncTokenWithLogs(t *testing.T) {
tests := map[string]*StreamingToken{ tests := map[string]*StreamingToken{
"s4_0_0_0": { "s4_0_0_0_0": {
PDUPosition: 4, PDUPosition: 4,
}, },
"s4_0_0_0.dl-0-123": { "s4_0_0_0_0.dl-0-123": {
PDUPosition: 4, PDUPosition: 4,
DeviceListPosition: LogPosition{ DeviceListPosition: LogPosition{
Partition: 0, Partition: 0,
@ -42,10 +42,10 @@ func TestNewSyncTokenWithLogs(t *testing.T) {
func TestSyncTokens(t *testing.T) { func TestSyncTokens(t *testing.T) {
shouldPass := map[string]string{ shouldPass := map[string]string{
"s4_0_0_0": StreamingToken{4, 0, 0, 0, LogPosition{}}.String(), "s4_0_0_0_0": StreamingToken{4, 0, 0, 0, 0, LogPosition{}}.String(),
"s3_1_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, LogPosition{1, 2}}.String(), "s3_1_0_0_0.dl-1-2": StreamingToken{3, 1, 0, 0, 0, LogPosition{1, 2}}.String(),
"s3_1_2_3": StreamingToken{3, 1, 2, 3, LogPosition{}}.String(), "s3_1_2_3_5": StreamingToken{3, 1, 2, 3, 5, LogPosition{}}.String(),
"t3_1": TopologyToken{3, 1}.String(), "t3_1": TopologyToken{3, 1}.String(),
} }
for a, b := range shouldPass { for a, b := range shouldPass {

View file

@ -141,18 +141,14 @@ New users appear in /keys/changes
Local delete device changes appear in v2 /sync Local delete device changes appear in v2 /sync
Local new device changes appear in v2 /sync Local new device changes appear in v2 /sync
Local update device changes appear in v2 /sync Local update device changes appear in v2 /sync
Users receive device_list updates for their own devices
Get left notifs for other users in sync and /keys/changes when user leaves Get left notifs for other users in sync and /keys/changes when user leaves
Local device key changes get to remote servers Local device key changes get to remote servers
Local device key changes get to remote servers with correct prev_id Local device key changes get to remote servers with correct prev_id
Server correctly handles incoming m.device_list_update Server correctly handles incoming m.device_list_update
Device deletion propagates over federation
If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in sync
If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room, changes device and rejoins we see update in /keys/changes
If remote user leaves room we no longer receive device updates If remote user leaves room we no longer receive device updates
If a device list update goes missing, the server resyncs on the next one If a device list update goes missing, the server resyncs on the next one
Get left notifs in sync and /keys/changes when other user leaves
Can query remote device keys using POST after notification
Server correctly resyncs when client query keys and there is no remote cache Server correctly resyncs when client query keys and there is no remote cache
Server correctly resyncs when server leaves and rejoins a room Server correctly resyncs when server leaves and rejoins a room
Device list doesn't change if remote server is down Device list doesn't change if remote server is down