Send-to-device support (#1072)

* Groundwork for send-to-device messaging

* Update sample config

* Add unstable routing for now

* Send to device consumer in sync API

* Start the send-to-device consumer

* fix indentation in dendrite-config.yaml

* Create send-to-device database tables, other tweaks

* Add some logic for send-to-device messages, add them into sync stream

* Handle incoming send-to-device messages, count them with EDU stream pos

* Undo changes to test

* pq.Array

* Fix sync

* Logging

* Fix a couple of transaction things, fix client API

* Add send-to-device test, hopefully fix bugs

* Comments

* Refactor a bit

* Fix schema

* Fix queries

* Debug logging

* Fix storing and retrieving of send-to-device messages

* Try to avoid database locks

* Update sync position

* Use latest sync position

* Jiggle about sync a bit

* Fix tests

* Break out the retrieval from the update/delete behaviour

* Comments

* nolint on getResponseWithPDUsForCompleteSync

* Try to line up sync tokens again

* Implement wildcard

* Add all send-to-device tests to whitelist, what could possibly go wrong?

* Only care about wildcard when targeted locally

* Deduplicate transactions

* Handle tokens properly, return immediately if waiting send-to-device messages

* Fix sync

* Update sytest-whitelist

* Fix copyright notice (need to do more of this)

* Comments, copyrights

* Return errors from Do, fix dendritejs

* Review comments

* Comments

* Constructor for TransactionWriter

* defletions

* Update gomatrixserverlib, sytest-blacklist
This commit is contained in:
Neil Alexander 2020-06-01 17:50:19 +01:00 committed by GitHub
parent 1f43c24f86
commit a5d822004d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
39 changed files with 1302 additions and 60 deletions

View file

@ -0,0 +1,113 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
type OutputSendToDeviceEventConsumer struct {
sendToDeviceConsumer *internal.ContinualConsumer
db storage.Database
serverName gomatrixserverlib.ServerName // our server name
notifier *sync.Notifier
}
// NewOutputSendToDeviceEventConsumer creates a new OutputSendToDeviceEventConsumer.
// Call Start() to begin consuming from the EDU server.
func NewOutputSendToDeviceEventConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
n *sync.Notifier,
store storage.Database,
) *OutputSendToDeviceEventConsumer {
consumer := internal.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputSendToDeviceEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputSendToDeviceEventConsumer{
sendToDeviceConsumer: &consumer,
db: store,
serverName: cfg.Matrix.ServerName,
notifier: n,
}
consumer.ProcessMessage = s.onMessage
return s
}
// Start consuming from EDU api
func (s *OutputSendToDeviceEventConsumer) Start() error {
return s.sendToDeviceConsumer.Start()
}
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var output api.OutputSendToDeviceEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("EDU server output log: message parse failure")
return err
}
_, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
return err
}
if domain != s.serverName {
return nil
}
util.GetLogger(context.TODO()).WithFields(log.Fields{
"sender": output.Sender,
"user_id": output.UserID,
"device_id": output.DeviceID,
"event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server")
streamPos := s.db.AddSendToDevice()
_, err = s.db.StoreNewSendForDeviceMessage(
context.TODO(), streamPos, output.UserID, output.DeviceID, output.SendToDeviceEvent,
)
if err != nil {
log.WithError(err).Errorf("failed to store send-to-device message")
return err
}
s.notifier.OnNewSendToDevice(
output.UserID,
[]string{output.DeviceID},
types.NewStreamToken(0, streamPos),
)
return nil
}

View file

@ -55,10 +55,12 @@ type Database interface {
// sync response for the given user. Events returned will include any client
// transaction IDs associated with the given device. These transaction IDs come
// from when the device sent the event via an API that included a transaction
// ID.
IncrementalSync(ctx context.Context, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
// CompleteSync returns a complete /sync API response for the given user.
CompleteSync(ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
// ID. A response object must be provided for IncrementaSync to populate - it
// will not create one.
IncrementalSync(ctx context.Context, res *types.Response, device authtypes.Device, fromPos, toPos types.StreamingToken, numRecentEventsPerRoom int, wantFullState bool) (*types.Response, error)
// CompleteSync returns a complete /sync API response for the given user. A response object
// must be provided for CompleteSync to populate - it will not create one.
CompleteSync(ctx context.Context, res *types.Response, device authtypes.Device, numRecentEventsPerRoom int) (*types.Response, error)
// GetAccountDataInRange returns all account data for a given user inserted or
// updated between two given positions
// Returns a map following the format data[roomID] = []dataTypes
@ -104,4 +106,26 @@ type Database interface {
StreamEventsToEvents(device *authtypes.Device, in []types.StreamEvent) []gomatrixserverlib.HeaderedEvent
// SyncStreamPosition returns the latest position in the sync stream. Returns 0 if there are no events yet.
SyncStreamPosition(ctx context.Context) (types.StreamPosition, error)
// AddSendToDevice increases the EDU position in the cache and returns the stream position.
AddSendToDevice() types.StreamPosition
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists:
// - "events": a list of send-to-device events that should be included in the sync
// - "changes": a list of send-to-device events that should be updated in the database by
// CleanSendToDeviceUpdates
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
}

View file

@ -0,0 +1,171 @@
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const sendToDeviceSchema = `
CREATE SEQUENCE IF NOT EXISTS syncapi_send_to_device_id;
-- Stores send-to-device messages.
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The ID that uniquely identifies this message.
id BIGINT PRIMARY KEY DEFAULT nextval('syncapi_send_to_device_id'),
-- The user ID to send the message to.
user_id TEXT NOT NULL,
-- The device ID to send the message to.
device_id TEXT NOT NULL,
-- The event content JSON.
content TEXT NOT NULL,
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
);
`
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
ORDER BY id DESC
`
const updateSentSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1
WHERE id = ANY($2)
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device WHERE id = ANY($1)
`
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt
}
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() {
var id types.SendToDeviceNID
var userID, deviceID, content string
var sentByToken *string
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event)
}
return events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
_, err = txn.Stmt(s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids))
return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) {
_, err = txn.Stmt(s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids))
return
}

View file

@ -69,6 +69,10 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
if err != nil {
return nil, err
}
sendToDevice, err := NewPostgresSendToDeviceTable(d.db)
if err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Invites: invites,
@ -77,6 +81,8 @@ func NewDatabase(dbDataSourceName string, dbProperties internal.DbProperties) (*
Topology: topology,
CurrentRoomState: currState,
BackwardExtremities: backwardExtremities,
SendToDevice: sendToDevice,
SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(),
}
return &d, nil

View file

@ -1,3 +1,17 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
@ -27,6 +41,8 @@ type Database struct {
Topology tables.Topology
CurrentRoomState tables.CurrentRoomState
BackwardExtremities tables.BackwardsExtremities
SendToDevice tables.SendToDevice
SendToDeviceWriter *internal.TransactionWriter
EDUCache *cache.EDUCache
}
@ -89,6 +105,10 @@ func (d *Database) RemoveTypingUser(
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
}
func (d *Database) AddSendToDevice() types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
}
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn)
}
@ -528,14 +548,14 @@ func (d *Database) addEDUDeltaToResponse(
}
func (d *Database) IncrementalSync(
ctx context.Context,
ctx context.Context, res *types.Response,
device authtypes.Device,
fromPos, toPos types.StreamingToken,
numRecentEventsPerRoom int,
wantFullState bool,
) (*types.Response, error) {
nextBatchPos := fromPos.WithUpdates(toPos)
res := types.NewResponse(nextBatchPos)
res.NextBatch = nextBatchPos.String()
var joinedRoomIDs []string
var err error
@ -568,12 +588,12 @@ func (d *Database) IncrementalSync(
// getResponseWithPDUsForCompleteSync creates a response and adds all PDUs needed
// to it. It returns toPos and joinedRoomIDs for use of adding EDUs.
// nolint:nakedret
func (d *Database) getResponseWithPDUsForCompleteSync(
ctx context.Context,
ctx context.Context, res *types.Response,
userID string,
numRecentEventsPerRoom int,
) (
res *types.Response,
toPos types.StreamingToken,
joinedRoomIDs []string,
err error,
@ -604,7 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
To: toPos.PDUPosition(),
}
res = types.NewResponse(toPos)
res.NextBatch = toPos.String()
// Extract room state and recent events for all rooms the user is joined to.
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -662,14 +682,15 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
}
succeeded = true
return res, toPos, joinedRoomIDs, err
return //res, toPos, joinedRoomIDs, err
}
func (d *Database) CompleteSync(
ctx context.Context, device authtypes.Device, numRecentEventsPerRoom int,
ctx context.Context, res *types.Response,
device authtypes.Device, numRecentEventsPerRoom int,
) (*types.Response, error) {
res, toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, device.UserID, numRecentEventsPerRoom,
toPos, joinedRoomIDs, err := d.getResponseWithPDUsForCompleteSync(
ctx, res, device.UserID, numRecentEventsPerRoom,
)
if err != nil {
return nil, err
@ -1028,6 +1049,115 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil
}
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) AddSendToDeviceEvent(
ctx context.Context, txn *sql.Tx,
userID, deviceID, content string,
) error {
return d.SendToDevice.InsertSendToDeviceMessage(
ctx, txn, userID, deviceID, content,
)
}
func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, streamPos types.StreamPosition, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (types.StreamPosition, error) {
j, err := json.Marshal(event)
if err != nil {
return streamPos, err
}
// Delegate the database write task to the SendToDeviceWriter. It'll guarantee
// that we don't lock the table for writes in more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
return d.AddSendToDeviceEvent(
ctx, txn, userID, deviceID, string(j),
)
})
if err != nil {
return streamPos, err
}
return streamPos, nil
}
func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context,
userID, deviceID string,
token types.StreamingToken,
) ([]types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) {
// First of all, get our send-to-device updates for this user.
events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
}
// If there's nothing to do then stop here.
if len(events) == 0 {
return nil, nil, nil, nil
}
// Work out whether we need to update any of the database entries.
toReturn := []types.SendToDeviceEvent{}
toUpdate := []types.SendToDeviceNID{}
toDelete := []types.SendToDeviceNID{}
for _, event := range events {
if event.SentByToken == nil {
// If the event has no sent-by token yet then we haven't attempted to send
// it. Record the current requested sync token in the database.
toUpdate = append(toUpdate, event.ID)
toReturn = append(toReturn, event)
event.SentByToken = &token
} else if token.IsAfter(*event.SentByToken) {
// The event had a sync token, therefore we've sent it before. The current
// sync token is now after the stored one so we can assume that the client
// successfully completed the previous sync (it would re-request it otherwise)
// so we can remove the entry from the database.
toDelete = append(toDelete, event.ID)
} else {
// It looks like the sync is being re-requested, maybe it timed out or
// failed. Re-send any that should have been acknowledged by now.
toReturn = append(toReturn, event)
}
}
return toReturn, toUpdate, toDelete, nil
}
func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID,
token types.StreamingToken,
) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 {
return nil
}
// If we need to write to the database then we'll ask the SendToDeviceWriter to
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.SendToDeviceWriter.Do(d.DB, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
}
return nil
})
return
}
// There may be some overlap where events in stateEvents are already in recentEvents, so filter
// them out so we don't include them twice in the /sync response. They should be in recentEvents
// only, so clients get to the correct state once they have rolled forward.

View file

@ -0,0 +1,172 @@
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types"
)
const sendToDeviceSchema = `
-- Stores send-to-device messages.
CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The ID that uniquely identifies this message.
id INTEGER PRIMARY KEY AUTOINCREMENT,
-- The user ID to send the message to.
user_id TEXT NOT NULL,
-- The device ID to send the message to.
device_id TEXT NOT NULL,
-- The event content JSON.
content TEXT NOT NULL,
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
);
`
const insertSendToDeviceMessageSQL = `
INSERT INTO syncapi_send_to_device (user_id, device_id, content)
VALUES ($1, $2, $3)
`
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
ORDER BY id DESC
`
const updateSentSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1
WHERE id IN ($2)
`
const deleteSendToDeviceMessagesSQL = `
DELETE FROM syncapi_send_to_device WHERE id IN ($1)
`
type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt
}
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
s := &sendToDeviceStatements{}
_, err := db.Exec(sendToDeviceSchema)
if err != nil {
return nil, err
}
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err
}
return s, nil
}
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
) (err error) {
_, err = internal.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
return
}
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := internal.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (events []types.SendToDeviceEvent, err error) {
rows, err := internal.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() {
var id types.SendToDeviceNID
var userID, deviceID, content string
var sentByToken *string
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return
}
event := types.SendToDeviceEvent{
ID: id,
UserID: userID,
DeviceID: deviceID,
}
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
}
events = append(events, event)
}
return events, rows.Err()
}
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID,
) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", internal.QueryVariadic(1+len(nids)), 1)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return
}
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID,
) (err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", internal.QueryVariadic(len(nids)), 1)
params := make([]interface{}, 1+len(nids))
for k, v := range nids {
params[k] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return
}

View file

@ -95,6 +95,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
if err != nil {
return err
}
sendToDevice, err := NewSqliteSendToDeviceTable(d.db)
if err != nil {
return err
}
d.Database = shared.Database{
DB: d.db,
Invites: invites,
@ -103,6 +107,8 @@ func (d *SyncServerDatasource) prepare() (err error) {
BackwardExtremities: bwExtrem,
CurrentRoomState: roomState,
Topology: topology,
SendToDevice: sendToDevice,
SendToDeviceWriter: internal.NewTransactionWriter(),
EDUCache: cache.New(),
}
return nil

View file

@ -3,6 +3,7 @@ package storage_test
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"testing"
"time"
@ -157,7 +158,8 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are at the penultimate event
positions[len(positions)-2], types.StreamPosition(0),
)
return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
res := types.NewResponse()
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
},
WantTimeline: events[len(events)-1:],
},
@ -169,8 +171,9 @@ func TestSyncResponse(t *testing.T) {
from := types.NewStreamToken( // pretend we are 10 events behind
positions[len(positions)-11], types.StreamPosition(0),
)
res := types.NewResponse()
// limit is set to 5
return db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
return db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
},
// want the last 5 events, NOT the last 10.
WantTimeline: events[len(events)-5:],
@ -180,8 +183,9 @@ func TestSyncResponse(t *testing.T) {
{
Name: "CompleteSync limited",
DoSync: func() (*types.Response, error) {
res := types.NewResponse()
// limit set to 5
return db.CompleteSync(ctx, testUserDeviceA, 5)
return db.CompleteSync(ctx, res, testUserDeviceA, 5)
},
// want the last 5 events
WantTimeline: events[len(events)-5:],
@ -193,7 +197,8 @@ func TestSyncResponse(t *testing.T) {
{
Name: "CompleteSync",
DoSync: func() (*types.Response, error) {
return db.CompleteSync(ctx, testUserDeviceA, len(events)+1)
res := types.NewResponse()
return db.CompleteSync(ctx, res, testUserDeviceA, len(events)+1)
},
WantTimeline: events,
// We want no state at all as that field in /sync is the delta between the token (beginning of time)
@ -234,7 +239,8 @@ func TestGetEventsInRangeWithPrevBatch(t *testing.T) {
positions[len(positions)-2], types.StreamPosition(0),
)
res, err := db.IncrementalSync(ctx, testUserDeviceA, from, latest, 5, false)
res := types.NewResponse()
res, err = db.IncrementalSync(ctx, res, testUserDeviceA, from, latest, 5, false)
if err != nil {
t.Fatalf("failed to IncrementalSync with latest token")
}
@ -512,6 +518,89 @@ func TestGetEventsInRangeWithEventsInsertedLikeBackfill(t *testing.T) {
}
}
func TestSendToDeviceBehaviour(t *testing.T) {
//t.Parallel()
db := MustCreateDatabase(t)
// At this point there should be no messages. We haven't sent anything
// yet.
events, updates, deletions, err := db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, 0))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("first call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, 0))
if err != nil {
return
}
// Try sending a message.
streamPos, err := db.StoreNewSendForDeviceMessage(ctx, types.StreamPosition(0), "alice", "one", gomatrixserverlib.SendToDeviceEvent{
Sender: "bob",
Type: "m.type",
Content: json.RawMessage("{}"),
})
if err != nil {
t.Fatal(err)
}
// At this point we should get exactly one message. We're sending the sync position
// that we were given from the update and the send-to-device update will be updated
// in the database to reflect that this was the sync position we sent the message at.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 1 || len(deletions) != 0 {
t.Fatal("second call should have one update")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should still have one message because we haven't progressed the
// sync position yet. This is equivalent to the client failing to /sync and retrying
// with the same position.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos))
if err != nil {
t.Fatal(err)
}
if len(events) != 1 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("third call should have one update still")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos))
if err != nil {
return
}
// At this point we should now have no updates, because we've progressed the sync
// position. Therefore the update from before will not be sent again.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+1))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 1 {
t.Fatal("fourth call should have no updates")
}
err = db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, types.NewStreamToken(0, streamPos+1))
if err != nil {
return
}
// At this point we should still have no updates, because no new updates have been
// sent.
events, updates, deletions, err = db.SendToDeviceUpdatesForSync(ctx, "alice", "one", types.NewStreamToken(0, streamPos+2))
if err != nil {
t.Fatal(err)
}
if len(events) != 0 || len(updates) != 0 || len(deletions) != 0 {
t.Fatal("fifth call should have no updates")
}
}
func assertEventsEqual(t *testing.T, msg string, checkRoomID bool, gots []gomatrixserverlib.ClientEvent, wants []gomatrixserverlib.HeaderedEvent) {
if len(gots) != len(wants) {
t.Fatalf("%s response returned %d events, want %d", msg, len(gots), len(wants))

View file

@ -1,3 +1,17 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
@ -94,3 +108,28 @@ type BackwardsExtremities interface {
// DeleteBackwardExtremity removes a backwards extremity for a room, if one existed.
DeleteBackwardExtremity(ctx context.Context, txn *sql.Tx, roomID, knownEventID string) (err error)
}
// SendToDevice tracks send-to-device messages which are sent to individual
// clients. Each message gets inserted into this table at the point that we
// receive it from the EDU server.
//
// We're supposed to try and do our best to deliver send-to-device messages
// once, but the only way that we can really guarantee that they have been
// delivered is if the client successfully requests the next sync as given
// in the next_batch. Each time the device syncs, we will request all of the
// updates that either haven't been sent yet, along with all updates that we
// *have* sent but we haven't confirmed to have been received yet. If it's the
// first time we're sending a given update then we update the table to say
// what the "since" parameter was when we tried to send it.
//
// When the client syncs again, if their "since" parameter is *later* than
// the recorded one, we drop the entry from the DB as it's "sent". If the
// sync parameter isn't later then we will keep including the updates in the
// sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
}

View file

@ -120,6 +120,18 @@ func (n *Notifier) OnNewEvent(
}
}
func (n *Notifier) OnNewSendToDevice(
userID string, deviceIDs []string,
posUpdate types.StreamingToken,
) {
n.streamLock.Lock()
defer n.streamLock.Unlock()
latestPos := n.currPos.WithUpdates(posUpdate)
n.currPos = latestPos
n.wakeupUserDevice(userID, deviceIDs, latestPos)
}
// GetListener returns a UserStreamListener that can be used to wait for
// updates for a user. Must be closed.
// notify for anything before sincePos
@ -189,8 +201,8 @@ func (n *Notifier) wakeupUsers(userIDs []string, newPos types.StreamingToken) {
// wakeupUserDevice will wake up the sync stream for a specific user device. Other
// device streams will be left alone.
// nolint:unused
func (n *Notifier) wakeupUserDevice(userDevices map[string]string, newPos types.StreamingToken) {
for userID, deviceID := range userDevices {
func (n *Notifier) wakeupUserDevice(userID string, deviceIDs []string, newPos types.StreamingToken) {
for _, deviceID := range deviceIDs {
if stream := n.fetchUserDeviceStream(userID, deviceID, false); stream != nil {
stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream
}

View file

@ -172,7 +172,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
time.Sleep(1 * time.Second)
wake := "two"
n.wakeupUserDevice(map[string]string{alice: wake}, syncPositionAfter)
n.wakeupUserDevice(alice, []string{wake}, syncPositionAfter)
if result := <-awoken; result != wake {
t.Fatalf("expected to wake %q, got %q", wake, result)

View file

@ -1,4 +1,6 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -15,6 +17,7 @@
package sync
import (
"context"
"net/http"
"time"
@ -54,17 +57,18 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
JSON: jsonerror.Unknown(err.Error()),
}
}
logger := util.GetLogger(req.Context()).WithFields(log.Fields{
"userID": device.UserID,
"deviceID": device.ID,
"since": syncReq.since,
"timeout": syncReq.timeout,
"limit": syncReq.limit,
"user_id": device.UserID,
"device_id": device.ID,
"since": syncReq.since,
"timeout": syncReq.timeout,
"limit": syncReq.limit,
})
currPos := rp.notifier.CurrentPosition()
if shouldReturnImmediately(syncReq) {
if rp.shouldReturnImmediately(syncReq) {
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -116,7 +120,6 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
// response. This ensures that we don't waste the hard work
// of calculating the sync only to get timed out before we
// can respond
syncData, err = rp.currentSyncForUser(*syncReq, currPos)
if err != nil {
logger.WithError(err).Error("rp.currentSyncForUser failed")
@ -134,19 +137,59 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype
}
func (rp *RequestPool) currentSyncForUser(req syncRequest, latestPos types.StreamingToken) (res *types.Response, err error) {
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, req.device, req.limit)
} else {
res, err = rp.db.IncrementalSync(req.ctx, req.device, *req.since, latestPos, req.limit, req.wantFullState)
res = types.NewResponse()
since := types.NewStreamToken(0, 0)
if req.since != nil {
since = *req.since
}
// See if we have any new tasks to do for the send-to-device messaging.
events, updates, deletions, err := rp.db.SendToDeviceUpdatesForSync(req.ctx, req.device.UserID, req.device.ID, since)
if err != nil {
return nil, err
}
// TODO: handle ignored users
if req.since == nil {
res, err = rp.db.CompleteSync(req.ctx, res, req.device, req.limit)
} else {
res, err = rp.db.IncrementalSync(req.ctx, res, req.device, *req.since, latestPos, req.limit, req.wantFullState)
}
if err != nil {
return
}
accountDataFilter := gomatrixserverlib.DefaultEventFilter() // TODO: use filter provided in req instead
res, err = rp.appendAccountData(res, req.device.UserID, req, latestPos.PDUPosition(), &accountDataFilter)
if err != nil {
return
}
// Before we return the sync response, make sure that we take action on
// any send-to-device database updates or deletions that we need to do.
// Then add the updates into the sync response.
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = rp.db.CleanSendToDeviceUpdates(context.Background(), updates, deletions, since)
if err != nil {
return
}
}
if len(events) > 0 {
// Add the updates into the sync response.
for _, event := range events {
res.ToDevice.Events = append(res.ToDevice.Events, event.SendToDeviceEvent)
}
// Get the next_batch from the sync response and increase the
// EDU counter.
if pos, perr := types.NewStreamTokenFromString(res.NextBatch); perr == nil {
pos.Positions[1]++
res.NextBatch = pos.String()
}
}
return
}
@ -238,6 +281,10 @@ func (rp *RequestPool) appendAccountData(
// shouldReturnImmediately returns whether the /sync request is an initial sync,
// or timeout=0, or full_state=true, in any of the cases the request should
// return immediately.
func shouldReturnImmediately(syncReq *syncRequest) bool {
return syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState
func (rp *RequestPool) shouldReturnImmediately(syncReq *syncRequest) bool {
if syncReq.since == nil || syncReq.timeout == 0 || syncReq.wantFullState {
return true
}
waiting, werr := rp.db.SendToDeviceUpdatesWaiting(context.TODO(), syncReq.device.UserID, syncReq.device.ID)
return werr == nil && waiting
}

View file

@ -78,7 +78,14 @@ func SetupSyncAPIComponent(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = typingConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start typing server consumer")
logrus.WithError(err).Panicf("failed to start typing consumer")
}
sendToDeviceConsumer := consumers.NewOutputSendToDeviceEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB,
)
if err = sendToDeviceConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start send-to-device consumer")
}
routing.Setup(base.PublicAPIMux, requestPool, syncDB, deviceDB, federation, rsAPI, cfg)

View file

@ -296,13 +296,14 @@ type Response struct {
Invite map[string]InviteResponse `json:"invite"`
Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"`
ToDevice struct {
Events []gomatrixserverlib.SendToDeviceEvent `json:"events"`
} `json:"to_device"`
}
// NewResponse creates an empty response with initialised maps.
func NewResponse(token StreamingToken) *Response {
res := Response{
NextBatch: token.String(),
}
func NewResponse() *Response {
res := Response{}
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse)
@ -315,6 +316,7 @@ func NewResponse(token StreamingToken) *Response {
// This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse.
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0)
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0)
return &res
}
@ -326,7 +328,8 @@ func (r *Response) IsEmpty() bool {
len(r.Rooms.Invite) == 0 &&
len(r.Rooms.Leave) == 0 &&
len(r.AccountData.Events) == 0 &&
len(r.Presence.Events) == 0
len(r.Presence.Events) == 0 &&
len(r.ToDevice.Events) == 0
}
// JoinResponse represents a /sync response for a room which is under the 'join' key.
@ -393,3 +396,13 @@ func NewLeaveResponse() *LeaveResponse {
res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0)
return &res
}
type SendToDeviceNID int
type SendToDeviceEvent struct {
gomatrixserverlib.SendToDeviceEvent
ID SendToDeviceNID
UserID string
DeviceID string
SentByToken *StreamingToken
}