Merge branch 'master' into 1323-archived-rooms-sync-left-rooms

This commit is contained in:
Neil Alexander 2021-01-13 17:37:07 +00:00 committed by GitHub
commit 1be95012d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 344 additions and 315 deletions

2
go.mod
View file

@ -22,7 +22,7 @@ require (
github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4 github.com/matrix-org/go-http-js-libp2p v0.0.0-20200518170932-783164aeeda4
github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3 github.com/matrix-org/go-sqlite3-js v0.0.0-20200522092705-bc8506ccbcf3
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.2 github.com/mattn/go-sqlite3 v1.14.2

4
go.sum
View file

@ -567,8 +567,8 @@ github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26 h1:Hr3zjRsq2bh
github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0= github.com/matrix-org/gomatrix v0.0.0-20190528120928-7df988a63f26/go.mod h1:3fxX6gUjWyI/2Bt7J1OLhpCzOfO/bB3AiX0cJtEKud0=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd h1:xVrqJK3xHREMNjwjljkAUaadalWc0rRbmVuQatzmgwg=
github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s= github.com/matrix-org/gomatrix v0.0.0-20200827122206-7dd5e2a05bcd/go.mod h1:/gBX06Kw0exX1HrwmoBibFA98yBk/jxKpGVeyQbff+s=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb h1:UlhiSebJupQ+qAM93cdVGg4nAJ6bnxwAA5/EBygtYoo= github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc h1:n2Hnbg8RZ4102Qmxie1riLkIyrqeqShJUILg1miSmDI=
github.com/matrix-org/gomatrixserverlib v0.0.0-20201209172200-eb6a8903f9fb/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/gomatrixserverlib v0.0.0-20210113173004-b1c67ac867cc/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91 h1:HJ6U3S3ljJqNffYMcIeAncp5qT/i+ZMiJ2JC2F0aXP4=
github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= github.com/matrix-org/naffka v0.0.0-20200901083833-bcdd62999a91/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE=
github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7 h1:ntrLa/8xVzeSs8vHFHK25k0C+NV74sYMJnNSg5NoSRo=

View file

@ -20,6 +20,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"time"
"github.com/matrix-org/dendrite/internal/eventutil" "github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
@ -28,9 +29,29 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
func init() {
prometheus.MustRegister(processRoomEventDuration)
}
var processRoomEventDuration = prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Namespace: "dendrite",
Subsystem: "roomserver",
Name: "processroomevent_duration_millis",
Help: "How long it takes the roomserver to process an event",
Buckets: []float64{ // milliseconds
5, 10, 25, 50, 75, 100, 250, 500,
1000, 2000, 3000, 4000, 5000, 6000,
7000, 8000, 9000, 10000, 15000, 20000,
},
},
[]string{"room_id"},
)
// processRoomEvent can only be called once at a time // processRoomEvent can only be called once at a time
// //
// TODO(#375): This should be rewritten to allow concurrent calls. The // TODO(#375): This should be rewritten to allow concurrent calls. The
@ -42,6 +63,15 @@ func (r *Inputer) processRoomEvent(
ctx context.Context, ctx context.Context,
input *api.InputRoomEvent, input *api.InputRoomEvent,
) (eventID string, err error) { ) (eventID string, err error) {
// Measure how long it takes to process this event.
started := time.Now()
defer func() {
timetaken := time.Since(started)
processRoomEventDuration.With(prometheus.Labels{
"room_id": input.Event.RoomID(),
}).Observe(float64(timetaken.Milliseconds()))
}()
// Parse and validate the event JSON // Parse and validate the event JSON
headered := input.Event headered := input.Event
event := headered.Unwrap() event := headered.Unwrap()

View file

@ -273,6 +273,14 @@ func (r *messagesReq) retrieveEvents() (
return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil return []gomatrixserverlib.ClientEvent{}, *r.from, *r.to, nil
} }
// Get the position of the first and the last event in the room's topology.
// This position is currently determined by the event's depth, so we could
// also use it instead of retrieving from the database. However, if we ever
// change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
// Sort the events to ensure we send them in the right order. // Sort the events to ensure we send them in the right order.
if r.backwardOrdering { if r.backwardOrdering {
// This reverses the array from old->new to new->old // This reverses the array from old->new to new->old
@ -292,14 +300,6 @@ func (r *messagesReq) retrieveEvents() (
// Convert all of the events into client events. // Convert all of the events into client events.
clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll) clientEvents = gomatrixserverlib.HeaderedToClientEvents(events, gomatrixserverlib.FormatAll)
// Get the position of the first and the last event in the room's topology.
// This position is currently determined by the event's depth, so we could
// also use it instead of retrieving from the database. However, if we ever
// change the way topological positions are defined (as depth isn't the most
// reliable way to define it), it would be easier and less troublesome to
// only have to change it in one place, i.e. the database.
start, end, err = r.getStartEnd(events)
return clientEvents, start, end, err return clientEvents, start, end, err
} }
@ -363,7 +363,7 @@ func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedE
return events // apply no filtering as it defaults to Shared. return events // apply no filtering as it defaults to Shared.
} }
hisVis, _ := hisVisEvent.HistoryVisibility() hisVis, _ := hisVisEvent.HistoryVisibility()
if hisVis == "shared" { if hisVis == "shared" || hisVis == "world_readable" {
return events // apply no filtering return events // apply no filtering
} }
if membershipEvent == nil { if membershipEvent == nil {
@ -388,26 +388,16 @@ func (r *messagesReq) filterHistoryVisible(events []*gomatrixserverlib.HeaderedE
} }
func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) { func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (start, end types.TopologyToken, err error) {
start, err = r.db.EventPositionInTopology( if r.backwardOrdering {
r.ctx, events[0].EventID(), start = *r.from
) if events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate {
if err != nil { // NOTSPEC: We've hit the beginning of the room so there's really nowhere
err = fmt.Errorf("EventPositionInTopology: for start event %s: %w", events[0].EventID(), err) // else to go. This seems to fix Riot iOS from looping on /messages endlessly.
return end = types.TopologyToken{}
} } else {
if r.backwardOrdering && events[len(events)-1].Type() == gomatrixserverlib.MRoomCreate { end, err = r.db.EventPositionInTopology(
// We've hit the beginning of the room so there's really nowhere else r.ctx, events[0].EventID(),
// to go. This seems to fix Riot iOS from looping on /messages endlessly. )
end = types.TopologyToken{}
} else {
end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
return
}
if r.backwardOrdering {
// A stream/topological position is a cursor located between two events. // A stream/topological position is a cursor located between two events.
// While they are identified in the code by the event on their right (if // While they are identified in the code by the event on their right (if
// we consider a left to right chronological order), tokens need to refer // we consider a left to right chronological order), tokens need to refer
@ -415,6 +405,15 @@ func (r *messagesReq) getStartEnd(events []*gomatrixserverlib.HeaderedEvent) (st
// end position we send in the response if we're going backward. // end position we send in the response if we're going backward.
end.Decrement() end.Decrement()
} }
} else {
start = *r.from
end, err = r.db.EventPositionInTopology(
r.ctx, events[len(events)-1].EventID(),
)
}
if err != nil {
err = fmt.Errorf("EventPositionInTopology: for end event %s: %w", events[len(events)-1].EventID(), err)
return
} }
return return
} }

View file

@ -33,6 +33,7 @@ type Database interface {
MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForReceipts(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForInvites(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error)
MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error)
CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error) CurrentState(ctx context.Context, roomID string, stateFilterPart *gomatrixserverlib.StateFilter) ([]*gomatrixserverlib.HeaderedEvent, error)
GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error) GetStateDeltasForFullStateSync(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
@ -117,26 +118,14 @@ type Database interface {
// matches the streamevent.transactionID device then the transaction ID gets // matches the streamevent.transactionID device then the transaction ID gets
// added to the unsigned section of the output event. // added to the unsigned section of the output event.
StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent StreamEventsToEvents(device *userapi.Device, in []types.StreamEvent) []*gomatrixserverlib.HeaderedEvent
// SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns three lists: // SendToDeviceUpdatesForSync returns a list of send-to-device updates. It returns the
// - "events": a list of send-to-device events that should be included in the sync // relevant events within the given ranges for the supplied user ID and device ID.
// - "changes": a list of send-to-device events that should be updated in the database by SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, from, to types.StreamPosition) (pos types.StreamPosition, events []types.SendToDeviceEvent, err error)
// CleanSendToDeviceUpdates
// - "deletions": a list of send-to-device events which have been confirmed as sent and
// can be deleted altogether by CleanSendToDeviceUpdates
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
SendToDeviceUpdatesForSync(ctx context.Context, userID, deviceID string, token types.StreamingToken) (pos types.StreamPosition, events []types.SendToDeviceEvent, changes []types.SendToDeviceNID, deletions []types.SendToDeviceNID, err error)
// StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device. // StoreNewSendForDeviceMessage stores a new send-to-device event for a user's device.
StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error) StoreNewSendForDeviceMessage(ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent) (types.StreamPosition, error)
// CleanSendToDeviceUpdates will update or remove any send-to-device updates based on the // CleanSendToDeviceUpdates removes all send-to-device messages BEFORE the specified
// result to a previous call to SendDeviceUpdatesForSync. This is separate as it allows // from position, preventing the send-to-device table from growing indefinitely.
// SendToDeviceUpdatesForSync to be called multiple times if needed (e.g. before and after CleanSendToDeviceUpdates(ctx context.Context, userID, deviceID string, before types.StreamPosition) (err error)
// starting to wait for an incremental sync with timeout).
// The token supplied should be the current requested sync token, e.g. from the "since"
// parameter.
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
// GetFilter looks up the filter associated with a given local user and filter ID. // GetFilter looks up the filter associated with a given local user and filter ID.
// Returns a filter structure. Otherwise returns an error if no such filter exists // Returns a filter structure. Otherwise returns an error if no such filter exists
// or if there was an error talking to the database. // or if there was an error talking to the database.

View file

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View file

@ -0,0 +1,48 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
DROP COLUMN IF EXISTS sent_by_token;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
ALTER TABLE syncapi_send_to_device
ADD COLUMN IF NOT EXISTS sent_by_token TEXT;
`)
if err != nil {
return fmt.Errorf("failed to execute downgrade: %w", err)
}
return nil
}

View file

@ -19,7 +19,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
@ -38,11 +37,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -52,34 +47,26 @@ const insertSendToDeviceMessageSQL = `
RETURNING id RETURNING id
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id = ANY($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id = ANY($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
updateSentSendToDeviceMessagesStmt *sql.Stmt selectMaxSendToDeviceIDStmt *sql.Stmt
deleteSendToDeviceMessagesStmt *sql.Stmt
} }
func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -91,16 +78,13 @@ func NewPostgresSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.updateSentSendToDeviceMessagesStmt, err = db.Prepare(updateSentSendToDeviceMessagesSQL); err != nil { if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil { if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err return nil, err
} }
return s, nil return s, nil
@ -113,64 +97,55 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil {
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return continue
}
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.updateSentSendToDeviceMessagesStmt).ExecContext(ctx, token, pq.Array(nids)) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
_, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, pq.Array(nids)) var nullableID sql.NullInt64
stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
err = stmt.QueryRowContext(ctx).Scan(&nullableID)
if nullableID.Valid {
id = nullableID.Int64
}
return return
} }

View file

@ -89,6 +89,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err return nil, err
} }

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -85,6 +86,14 @@ func (d *Database) MaxStreamPositionForInvites(ctx context.Context) (types.Strea
return types.StreamPosition(id), nil return types.StreamPosition(id), nil
} }
func (d *Database) MaxStreamPositionForSendToDeviceMessages(ctx context.Context) (types.StreamPosition, error) {
id, err := d.SendToDevice.SelectMaxSendToDeviceMessageID(ctx, nil)
if err != nil {
return 0, fmt.Errorf("d.SendToDevice.SelectMaxSendToDeviceMessageID: %w", err)
}
return types.StreamPosition(id), nil
}
func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) { func (d *Database) MaxStreamPositionForAccountData(ctx context.Context) (types.StreamPosition, error) {
id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil) id, err := d.AccountData.SelectMaxAccountDataID(ctx, nil)
if err != nil { if err != nil {
@ -168,30 +177,6 @@ func (d *Database) GetEventsInStreamingRange(
return events, err return events, err
} }
/*
func (d *Database) AddTypingUser(
userID, roomID string, expireTime *time.Time,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddTypingUser(userID, roomID, expireTime))
}
func (d *Database) RemoveTypingUser(
userID, roomID string,
) types.StreamPosition {
return types.StreamPosition(d.EDUCache.RemoveUser(userID, roomID))
}
func (d *Database) SetTypingTimeoutCallback(fn cache.TimeoutCallbackFn) {
d.EDUCache.SetTimeoutCallback(fn)
}
*/
/*
func (d *Database) AddSendToDevice() types.StreamPosition {
return types.StreamPosition(d.EDUCache.AddSendToDeviceMessage())
}
*/
func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) { func (d *Database) AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error) {
return d.CurrentRoomState.SelectJoinedUsers(ctx) return d.CurrentRoomState.SelectJoinedUsers(ctx)
} }
@ -891,16 +876,6 @@ func (d *Database) currentStateStreamEventsForRoom(
return s, nil return s, nil
} }
func (d *Database) SendToDeviceUpdatesWaiting(
ctx context.Context, userID, deviceID string,
) (bool, error) {
count, err := d.SendToDevice.CountSendToDeviceMessages(ctx, nil, userID, deviceID)
if err != nil {
return false, err
}
return count > 0, nil
}
func (d *Database) StoreNewSendForDeviceMessage( func (d *Database) StoreNewSendForDeviceMessage(
ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent, ctx context.Context, userID, deviceID string, event gomatrixserverlib.SendToDeviceEvent,
) (newPos types.StreamPosition, err error) { ) (newPos types.StreamPosition, err error) {
@ -919,77 +894,37 @@ func (d *Database) StoreNewSendForDeviceMessage(
if err != nil { if err != nil {
return 0, err return 0, err
} }
return 0, nil return newPos, nil
} }
func (d *Database) SendToDeviceUpdatesForSync( func (d *Database) SendToDeviceUpdatesForSync(
ctx context.Context, ctx context.Context,
userID, deviceID string, userID, deviceID string,
token types.StreamingToken, from, to types.StreamPosition,
) (types.StreamPosition, []types.SendToDeviceEvent, []types.SendToDeviceNID, []types.SendToDeviceNID, error) { ) (types.StreamPosition, []types.SendToDeviceEvent, error) {
// First of all, get our send-to-device updates for this user. // First of all, get our send-to-device updates for this user.
lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID) lastPos, events, err := d.SendToDevice.SelectSendToDeviceMessages(ctx, nil, userID, deviceID, from, to)
if err != nil { if err != nil {
return 0, nil, nil, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err) return from, nil, fmt.Errorf("d.SendToDevice.SelectSendToDeviceMessages: %w", err)
} }
// If there's nothing to do then stop here. // If there's nothing to do then stop here.
if len(events) == 0 { if len(events) == 0 {
return 0, nil, nil, nil, nil return to, nil, nil
} }
return lastPos, events, nil
// Work out whether we need to update any of the database entries.
toReturn := []types.SendToDeviceEvent{}
toUpdate := []types.SendToDeviceNID{}
toDelete := []types.SendToDeviceNID{}
for _, event := range events {
if event.SentByToken == nil {
// If the event has no sent-by token yet then we haven't attempted to send
// it. Record the current requested sync token in the database.
toUpdate = append(toUpdate, event.ID)
toReturn = append(toReturn, event)
event.SentByToken = &token
} else if token.IsAfter(*event.SentByToken) {
// The event had a sync token, therefore we've sent it before. The current
// sync token is now after the stored one so we can assume that the client
// successfully completed the previous sync (it would re-request it otherwise)
// so we can remove the entry from the database.
toDelete = append(toDelete, event.ID)
} else {
// It looks like the sync is being re-requested, maybe it timed out or
// failed. Re-send any that should have been acknowledged by now.
toReturn = append(toReturn, event)
}
}
return lastPos, toReturn, toUpdate, toDelete, nil
} }
func (d *Database) CleanSendToDeviceUpdates( func (d *Database) CleanSendToDeviceUpdates(
ctx context.Context, ctx context.Context,
toUpdate, toDelete []types.SendToDeviceNID, userID, deviceID string, before types.StreamPosition,
token types.StreamingToken,
) (err error) { ) (err error) {
if len(toUpdate) == 0 && len(toDelete) == 0 { if err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return nil return d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, userID, deviceID, before)
}); err != nil {
logrus.WithError(err).Errorf("Failed to clean up old send-to-device messages for user %q device %q", userID, deviceID)
return err
} }
// If we need to write to the database then we'll ask the SendToDeviceWriter to return nil
// do that for us. It'll guarantee that we don't lock the table for writes in
// more than one place.
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// Delete any send-to-device messages marked for deletion.
if e := d.SendToDevice.DeleteSendToDeviceMessages(ctx, txn, toDelete); e != nil {
return fmt.Errorf("d.SendToDevice.DeleteSendToDeviceMessages: %w", e)
}
// Now update any outstanding send-to-device messages with the new sync token.
if e := d.SendToDevice.UpdateSentSendToDeviceMessages(ctx, txn, token.String(), toUpdate); e != nil {
return fmt.Errorf("d.SendToDevice.UpdateSentSendToDeviceMessages: %w", err)
}
return nil
})
return
} }
// getMembershipFromEvent returns the value of content.membership iff the event is a state event // getMembershipFromEvent returns the value of content.membership iff the event is a state event

View file

@ -24,6 +24,7 @@ import (
func LoadFromGoose() { func LoadFromGoose() {
goose.AddMigration(UpFixSequences, DownFixSequences) goose.AddMigration(UpFixSequences, DownFixSequences)
goose.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
} }
func LoadFixSequences(m *sqlutil.Migrations) { func LoadFixSequences(m *sqlutil.Migrations) {

View file

@ -0,0 +1,67 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
func LoadRemoveSendToDeviceSentColumn(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveSendToDeviceSentColumn, DownRemoveSendToDeviceSentColumn)
}
func UpRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveSendToDeviceSentColumn(tx *sql.Tx) error {
_, err := tx.Exec(`
CREATE TEMPORARY TABLE syncapi_send_to_device_backup(id, user_id, device_id, content);
INSERT INTO syncapi_send_to_device_backup SELECT id, user_id, device_id, content FROM syncapi_send_to_device;
DROP TABLE syncapi_send_to_device;
CREATE TABLE syncapi_send_to_device(
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL,
sent_by_token TEXT
);
INSERT INTO syncapi_send_to_device SELECT id, user_id, device_id, content FROM syncapi_send_to_device_backup;
DROP TABLE syncapi_send_to_device_backup;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}

View file

@ -18,12 +18,12 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strings"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/syncapi/storage/tables" "github.com/matrix-org/dendrite/syncapi/storage/tables"
"github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/dendrite/syncapi/types"
"github.com/sirupsen/logrus"
) )
const sendToDeviceSchema = ` const sendToDeviceSchema = `
@ -36,11 +36,7 @@ CREATE TABLE IF NOT EXISTS syncapi_send_to_device (
-- The device ID to send the message to. -- The device ID to send the message to.
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
-- The event content JSON. -- The event content JSON.
content TEXT NOT NULL, content TEXT NOT NULL
-- The token that was supplied to the /sync at the time that this
-- message was included in a sync response, or NULL if we haven't
-- included it in a /sync response yet.
sent_by_token TEXT
); );
` `
@ -49,33 +45,27 @@ const insertSendToDeviceMessageSQL = `
VALUES ($1, $2, $3) VALUES ($1, $2, $3)
` `
const countSendToDeviceMessagesSQL = `
SELECT COUNT(*)
FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2
`
const selectSendToDeviceMessagesSQL = ` const selectSendToDeviceMessagesSQL = `
SELECT id, user_id, device_id, content, sent_by_token SELECT id, user_id, device_id, content
FROM syncapi_send_to_device FROM syncapi_send_to_device
WHERE user_id = $1 AND device_id = $2 WHERE user_id = $1 AND device_id = $2 AND id > $3 AND id <= $4
ORDER BY id DESC ORDER BY id DESC
` `
const updateSentSendToDeviceMessagesSQL = ` const deleteSendToDeviceMessagesSQL = `
UPDATE syncapi_send_to_device SET sent_by_token = $1 DELETE FROM syncapi_send_to_device
WHERE id IN ($2) WHERE user_id = $1 AND device_id = $2 AND id < $3
` `
const deleteSendToDeviceMessagesSQL = ` const selectMaxSendToDeviceIDSQL = "" +
DELETE FROM syncapi_send_to_device WHERE id IN ($1) "SELECT MAX(id) FROM syncapi_send_to_device"
`
type sendToDeviceStatements struct { type sendToDeviceStatements struct {
db *sql.DB db *sql.DB
insertSendToDeviceMessageStmt *sql.Stmt insertSendToDeviceMessageStmt *sql.Stmt
selectSendToDeviceMessagesStmt *sql.Stmt selectSendToDeviceMessagesStmt *sql.Stmt
countSendToDeviceMessagesStmt *sql.Stmt deleteSendToDeviceMessagesStmt *sql.Stmt
selectMaxSendToDeviceIDStmt *sql.Stmt
} }
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) { func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
@ -86,15 +76,18 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if s.countSendToDeviceMessagesStmt, err = db.Prepare(countSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil { if s.insertSendToDeviceMessageStmt, err = db.Prepare(insertSendToDeviceMessageSQL); err != nil {
return nil, err return nil, err
} }
if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil { if s.selectSendToDeviceMessagesStmt, err = db.Prepare(selectSendToDeviceMessagesSQL); err != nil {
return nil, err return nil, err
} }
if s.deleteSendToDeviceMessagesStmt, err = db.Prepare(deleteSendToDeviceMessagesSQL); err != nil {
return nil, err
}
if s.selectMaxSendToDeviceIDStmt, err = db.Prepare(selectMaxSendToDeviceIDSQL); err != nil {
return nil, err
}
return s, nil return s, nil
} }
@ -111,75 +104,57 @@ func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
return return
} }
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string,
) (count int, err error) {
row := sqlutil.TxStmt(txn, s.countSendToDeviceMessagesStmt).QueryRowContext(ctx, userID, deviceID)
if err = row.Scan(&count); err != nil {
return
}
return count, nil
}
func (s *sendToDeviceStatements) SelectSendToDeviceMessages( func (s *sendToDeviceStatements) SelectSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, userID, deviceID string, ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition,
) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) { ) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID) rows, err := sqlutil.TxStmt(txn, s.selectSendToDeviceMessagesStmt).QueryContext(ctx, userID, deviceID, from, to)
if err != nil { if err != nil {
return return
} }
defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "SelectSendToDeviceMessages: rows.close() failed")
for rows.Next() { for rows.Next() {
var id types.SendToDeviceNID var id types.StreamPosition
var userID, deviceID, content string var userID, deviceID, content string
var sentByToken *string if err = rows.Scan(&id, &userID, &deviceID, &content); err != nil {
if err = rows.Scan(&id, &userID, &deviceID, &content, &sentByToken); err != nil { logrus.WithError(err).Errorf("Failed to retrieve send-to-device message")
return return
} }
if id > lastPos {
lastPos = id
}
event := types.SendToDeviceEvent{ event := types.SendToDeviceEvent{
ID: id, ID: id,
UserID: userID, UserID: userID,
DeviceID: deviceID, DeviceID: deviceID,
} }
if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil { if err = json.Unmarshal([]byte(content), &event.SendToDeviceEvent); err != nil {
return logrus.WithError(err).Errorf("Failed to unmarshal send-to-device message")
} continue
if sentByToken != nil {
if token, err := types.NewStreamTokenFromString(*sentByToken); err == nil {
event.SentByToken = &token
}
} }
events = append(events, event) events = append(events, event)
if types.StreamPosition(id) > lastPos {
lastPos = types.StreamPosition(id)
}
} }
if lastPos == 0 {
lastPos = to
}
return lastPos, events, rows.Err() return lastPos, events, rows.Err()
} }
func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages( func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx, userID, deviceID string, pos types.StreamPosition,
) (err error) { ) (err error) {
query := strings.Replace(updateSentSendToDeviceMessagesSQL, "($2)", sqlutil.QueryVariadic(1+len(nids)), 1) _, err = sqlutil.TxStmt(txn, s.deleteSendToDeviceMessagesStmt).ExecContext(ctx, userID, deviceID, pos)
params := make([]interface{}, 1+len(nids))
params[0] = token
for k, v := range nids {
params[k+1] = v
}
_, err = txn.ExecContext(ctx, query, params...)
return return
} }
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages( func (s *sendToDeviceStatements) SelectMaxSendToDeviceMessageID(
ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID, ctx context.Context, txn *sql.Tx,
) (err error) { ) (id int64, err error) {
query := strings.Replace(deleteSendToDeviceMessagesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) var nullableID sql.NullInt64
params := make([]interface{}, 1+len(nids)) stmt := sqlutil.TxStmt(txn, s.selectMaxSendToDeviceIDStmt)
for k, v := range nids { err = stmt.QueryRowContext(ctx).Scan(&nullableID)
params[k] = v if nullableID.Valid {
id = nullableID.Int64
} }
_, err = txn.ExecContext(ctx, query, params...)
return return
} }

View file

@ -102,6 +102,7 @@ func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (er
} }
m := sqlutil.NewMigrations() m := sqlutil.NewMigrations()
deltas.LoadFixSequences(m) deltas.LoadFixSequences(m)
deltas.LoadRemoveSendToDeviceSentColumn(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil { if err = m.RunDeltas(d.db, dbProperties); err != nil {
return err return err
} }

View file

@ -147,10 +147,9 @@ type BackwardsExtremities interface {
// sync response, as the client is seemingly trying to repeat the same /sync. // sync response, as the client is seemingly trying to repeat the same /sync.
type SendToDevice interface { type SendToDevice interface {
InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error) InsertSendToDeviceMessage(ctx context.Context, txn *sql.Tx, userID, deviceID, content string) (pos types.StreamPosition, err error)
SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error) SelectSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from, to types.StreamPosition) (lastPos types.StreamPosition, events []types.SendToDeviceEvent, err error)
UpdateSentSendToDeviceMessages(ctx context.Context, txn *sql.Tx, token string, nids []types.SendToDeviceNID) (err error) DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string, from types.StreamPosition) (err error)
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error) SelectMaxSendToDeviceMessageID(ctx context.Context, txn *sql.Tx) (id int64, err error)
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
} }
type Filter interface { type Filter interface {

View file

@ -82,11 +82,6 @@ func (p *AccountDataStreamProvider) IncrementalSync(
return from return from
} }
if len(dataTypes) == 0 {
// TODO: this fixes the sytest but is it the right thing to do?
dataTypes[""] = []string{"m.push_rules"}
}
// Iterate over the rooms // Iterate over the rooms
for roomID, dataTypes := range dataTypes { for roomID, dataTypes := range dataTypes {
// Request the missing data from the database // Request the missing data from the database
@ -114,7 +109,10 @@ func (p *AccountDataStreamProvider) IncrementalSync(
} }
} else { } else {
if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok { if roomData, ok := dataRes.RoomAccountData[roomID][dataType]; ok {
joinData := req.Response.Rooms.Join[roomID] joinData := *types.NewJoinResponse()
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
joinData = existing
}
joinData.AccountData.Events = append( joinData.AccountData.Events = append(
joinData.AccountData.Events, joinData.AccountData.Events,
gomatrixserverlib.ClientEvent{ gomatrixserverlib.ClientEvent{

View file

@ -59,7 +59,10 @@ func (p *ReceiptStreamProvider) IncrementalSync(
} }
for roomID, receipts := range receiptsByRoom { for roomID, receipts := range receiptsByRoom {
jr := req.Response.Rooms.Join[roomID] jr := *types.NewJoinResponse()
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
jr = existing
}
var ok bool var ok bool
ev := gomatrixserverlib.ClientEvent{ ev := gomatrixserverlib.ClientEvent{

View file

@ -10,6 +10,16 @@ type SendToDeviceStreamProvider struct {
StreamProvider StreamProvider
} }
func (p *SendToDeviceStreamProvider) Setup() {
p.StreamProvider.Setup()
id, err := p.DB.MaxStreamPositionForSendToDeviceMessages(context.Background())
if err != nil {
panic(err)
}
p.latest = id
}
func (p *SendToDeviceStreamProvider) CompleteSync( func (p *SendToDeviceStreamProvider) CompleteSync(
ctx context.Context, ctx context.Context,
req *types.SyncRequest, req *types.SyncRequest,
@ -23,24 +33,19 @@ func (p *SendToDeviceStreamProvider) IncrementalSync(
from, to types.StreamPosition, from, to types.StreamPosition,
) types.StreamPosition { ) types.StreamPosition {
// See if we have any new tasks to do for the send-to-device messaging. // See if we have any new tasks to do for the send-to-device messaging.
lastPos, events, updates, deletions, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, req.Since) lastPos, events, err := p.DB.SendToDeviceUpdatesForSync(req.Context, req.Device.UserID, req.Device.ID, from, to)
if err != nil { if err != nil {
req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed") req.Log.WithError(err).Error("p.DB.SendToDeviceUpdatesForSync failed")
return from return from
} }
// Before we return the sync response, make sure that we take action on if len(events) > 0 {
// any send-to-device database updates or deletions that we need to do. // Clean up old send-to-device messages from before this stream position.
// Then add the updates into the sync response. if err := p.DB.CleanSendToDeviceUpdates(req.Context, req.Device.UserID, req.Device.ID, from); err != nil {
if len(updates) > 0 || len(deletions) > 0 {
// Handle the updates and deletions in the database.
err = p.DB.CleanSendToDeviceUpdates(context.Background(), updates, deletions, req.Since)
if err != nil {
req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed") req.Log.WithError(err).Error("p.DB.CleanSendToDeviceUpdates failed")
return from return from
} }
}
if len(events) > 0 {
// Add the updates into the sync response. // Add the updates into the sync response.
for _, event := range events { for _, event := range events {
req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent) req.Response.ToDevice.Events = append(req.Response.ToDevice.Events, event.SendToDeviceEvent)

View file

@ -32,7 +32,10 @@ func (p *TypingStreamProvider) IncrementalSync(
continue continue
} }
jr := req.Response.Rooms.Join[roomID] jr := *types.NewJoinResponse()
if existing, ok := req.Response.Rooms.Join[roomID]; ok {
jr = existing
}
if users, updated := p.EDUCache.GetTypingUsersIfUpdatedAfter( if users, updated := p.EDUCache.GetTypingUsersIfUpdatedAfter(
roomID, int64(from), roomID, int64(from),

View file

@ -360,11 +360,11 @@ type PrevEventRef struct {
type Response struct { type Response struct {
NextBatch StreamingToken `json:"next_batch"` NextBatch StreamingToken `json:"next_batch"`
AccountData struct { AccountData struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
} `json:"account_data,omitempty"` } `json:"account_data"`
Presence struct { Presence struct {
Events []gomatrixserverlib.ClientEvent `json:"events"` Events []gomatrixserverlib.ClientEvent `json:"events,omitempty"`
} `json:"presence,omitempty"` } `json:"presence"`
Rooms struct { Rooms struct {
Join map[string]JoinResponse `json:"join"` Join map[string]JoinResponse `json:"join"`
Peek map[string]JoinResponse `json:"peek"` Peek map[string]JoinResponse `json:"peek"`
@ -372,13 +372,13 @@ type Response struct {
Leave map[string]LeaveResponse `json:"leave"` Leave map[string]LeaveResponse `json:"leave"`
} `json:"rooms"` } `json:"rooms"`
ToDevice struct { ToDevice struct {
Events []gomatrixserverlib.SendToDeviceEvent `json:"events"` Events []gomatrixserverlib.SendToDeviceEvent `json:"events,omitempty"`
} `json:"to_device"` } `json:"to_device"`
DeviceLists struct { DeviceLists struct {
Changed []string `json:"changed,omitempty"` Changed []string `json:"changed,omitempty"`
Left []string `json:"left,omitempty"` Left []string `json:"left,omitempty"`
} `json:"device_lists,omitempty"` } `json:"device_lists"`
DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count"` DeviceListsOTKCount map[string]int `json:"device_one_time_keys_count,omitempty"`
} }
// NewResponse creates an empty response with initialised maps. // NewResponse creates an empty response with initialised maps.
@ -386,19 +386,19 @@ func NewResponse() *Response {
res := Response{} res := Response{}
// Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section, // Pre-initialise the maps. Synapse will return {} even if there are no rooms under a specific section,
// so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors. // so let's do the same thing. Bonus: this means we can't get dreaded 'assignment to entry in nil map' errors.
res.Rooms.Join = make(map[string]JoinResponse) res.Rooms.Join = map[string]JoinResponse{}
res.Rooms.Peek = make(map[string]JoinResponse) res.Rooms.Peek = map[string]JoinResponse{}
res.Rooms.Invite = make(map[string]InviteResponse) res.Rooms.Invite = map[string]InviteResponse{}
res.Rooms.Leave = make(map[string]LeaveResponse) res.Rooms.Leave = map[string]LeaveResponse{}
// Also pre-intialise empty slices or else we'll insert 'null' instead of '[]' for the value. // Also pre-intialise empty slices or else we'll insert 'null' instead of '[]' for the value.
// TODO: We really shouldn't have to do all this to coerce encoding/json to Do The Right Thing. We should // TODO: We really shouldn't have to do all this to coerce encoding/json to Do The Right Thing. We should
// really be using our own Marshal/Unmarshal implementations otherwise this may prove to be a CPU bottleneck. // really be using our own Marshal/Unmarshal implementations otherwise this may prove to be a CPU bottleneck.
// This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse. // This also applies to NewJoinResponse, NewInviteResponse and NewLeaveResponse.
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.AccountData.Events = []gomatrixserverlib.ClientEvent{}
res.Presence.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Presence.Events = []gomatrixserverlib.ClientEvent{}
res.ToDevice.Events = make([]gomatrixserverlib.SendToDeviceEvent, 0) res.ToDevice.Events = []gomatrixserverlib.SendToDeviceEvent{}
res.DeviceListsOTKCount = make(map[string]int) res.DeviceListsOTKCount = map[string]int{}
return &res return &res
} }
@ -435,10 +435,10 @@ type JoinResponse struct {
// NewJoinResponse creates an empty response with initialised arrays. // NewJoinResponse creates an empty response with initialised arrays.
func NewJoinResponse() *JoinResponse { func NewJoinResponse() *JoinResponse {
res := JoinResponse{} res := JoinResponse{}
res.State.Events = make([]gomatrixserverlib.ClientEvent, 0) res.State.Events = []gomatrixserverlib.ClientEvent{}
res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Timeline.Events = []gomatrixserverlib.ClientEvent{}
res.Ephemeral.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Ephemeral.Events = []gomatrixserverlib.ClientEvent{}
res.AccountData.Events = make([]gomatrixserverlib.ClientEvent, 0) res.AccountData.Events = []gomatrixserverlib.ClientEvent{}
return &res return &res
} }
@ -487,19 +487,16 @@ type LeaveResponse struct {
// NewLeaveResponse creates an empty response with initialised arrays. // NewLeaveResponse creates an empty response with initialised arrays.
func NewLeaveResponse() *LeaveResponse { func NewLeaveResponse() *LeaveResponse {
res := LeaveResponse{} res := LeaveResponse{}
res.State.Events = make([]gomatrixserverlib.ClientEvent, 0) res.State.Events = []gomatrixserverlib.ClientEvent{}
res.Timeline.Events = make([]gomatrixserverlib.ClientEvent, 0) res.Timeline.Events = []gomatrixserverlib.ClientEvent{}
return &res return &res
} }
type SendToDeviceNID int
type SendToDeviceEvent struct { type SendToDeviceEvent struct {
gomatrixserverlib.SendToDeviceEvent gomatrixserverlib.SendToDeviceEvent
ID SendToDeviceNID ID StreamPosition
UserID string UserID string
DeviceID string DeviceID string
SentByToken *StreamingToken
} }
type PeekingDevice struct { type PeekingDevice struct {

View file

@ -502,3 +502,5 @@ Can forget room you've been kicked from
/joined_members return joined members /joined_members return joined members
A next_batch token can be used in the v1 messages API A next_batch token can be used in the v1 messages API
Users receive device_list updates for their own devices Users receive device_list updates for their own devices
m.room.history_visibility == "world_readable" allows/forbids appropriately for Guest users
m.room.history_visibility == "world_readable" allows/forbids appropriately for Real users