Send device list updates to servers (outbound only) (#1237)

* Add QueryDeviceMessages to serve up device keys and stream IDs

* Consume key change events in fedsender

Don't yet send them to destinations as we haven't worked them out yet

* Send device list updates to all required servers

* Glue it all together
This commit is contained in:
Kegsay 2020-08-04 11:32:14 +01:00 committed by GitHub
parent fb56bbf0b7
commit 0c4e8f6d4f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 328 additions and 50 deletions

View file

@ -0,0 +1,135 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"fmt"
"github.com/Shopify/sarama"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/queue"
"github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/config"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
// KeyChangeConsumer consumes events that originate in key server.
type KeyChangeConsumer struct {
consumer *internal.ContinualConsumer
db storage.Database
queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName
stateAPI stateapi.CurrentStateInternalAPI
}
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
func NewKeyChangeConsumer(
cfg *config.Dendrite,
kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues,
store storage.Database,
stateAPI stateapi.CurrentStateInternalAPI,
) *KeyChangeConsumer {
c := &KeyChangeConsumer{
consumer: &internal.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputKeyChangeEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
},
queues: queues,
db: store,
serverName: cfg.Matrix.ServerName,
stateAPI: stateAPI,
}
c.consumer.ProcessMessage = c.onMessage
return c
}
// Start consuming from key servers
func (t *KeyChangeConsumer) Start() error {
if err := t.consumer.Start(); err != nil {
return fmt.Errorf("t.consumer.Start: %w", err)
}
return nil
}
// onMessage is called in response to a message received on the
// key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var m api.DeviceMessage
if err := json.Unmarshal(msg.Value, &m); err != nil {
log.WithError(err).Errorf("failed to read device message from key change topic")
return nil
}
logger := log.WithField("user_id", m.UserID)
// only send key change events which originated from us
_, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
if err != nil {
logger.WithError(err).Error("Failed to extract domain from key change event")
return nil
}
if originServerName != t.serverName {
return nil
}
var queryRes stateapi.QueryRoomsForUserResponse
err = t.stateAPI.QueryRoomsForUser(context.Background(), &stateapi.QueryRoomsForUserRequest{
UserID: m.UserID,
WantMembership: "join",
}, &queryRes)
if err != nil {
logger.WithError(err).Error("failed to calculate joined rooms for user")
return nil
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
if err != nil {
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
return nil
}
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDeviceListUpdate,
Origin: string(t.serverName),
}
event := gomatrixserverlib.DeviceListUpdateEvent{
UserID: m.UserID,
DeviceID: m.DeviceID,
DeviceDisplayName: m.DisplayName,
StreamID: m.StreamID,
PrevID: prevID(m.StreamID),
Deleted: len(m.KeyJSON) == 0,
Keys: m.KeyJSON,
}
if edu.Content, err = json.Marshal(event); err != nil {
return err
}
log.Infof("Sending device list update message to %q", destinations)
return t.queues.SendEDU(edu, t.serverName, destinations)
}
func prevID(streamID int) []int {
if streamID <= 1 {
return nil
}
return []int{streamID - 1}
}

View file

@ -16,6 +16,7 @@ package federationsender
import (
"github.com/gorilla/mux"
stateapi "github.com/matrix-org/dendrite/currentstateserver/api"
"github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/internal"
@ -41,6 +42,7 @@ func NewInternalAPI(
base *setup.BaseDendrite,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
stateAPI stateapi.CurrentStateInternalAPI,
keyRing *gomatrixserverlib.KeyRing,
) api.FederationSenderInternalAPI {
federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender), base.Cfg.DbProperties())
@ -76,6 +78,12 @@ func NewInternalAPI(
if err := tsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start typing server consumer")
}
keyConsumer := consumers.NewKeyChangeConsumer(
base.Cfg, base.KafkaConsumer, queues, federationSenderDB, stateAPI,
)
if err := keyConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start key server consumer")
}
return internal.NewFederationSenderInternalAPI(federationSenderDB, base.Cfg, rsAPI, federation, keyRing, stats, queues)
}

View file

@ -30,6 +30,8 @@ type Database interface {
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)

View file

@ -60,12 +60,16 @@ const selectJoinedHostsSQL = "" +
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt
}
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -88,6 +92,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
return
}
@ -144,6 +151,27 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
return result, rows.Err()
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {

View file

@ -123,6 +123,10 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
return d.FederationSenderJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
}
// StoreJSON adds a JSON blob into the queue JSON table and returns
// a NID. The NID will then be used when inserting the per-destination
// metadata entries.

View file

@ -59,13 +59,17 @@ const selectJoinedHostsSQL = "" +
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct {
db *sql.DB
writer *sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
db *sql.DB
writer *sqlutil.TransactionWriter
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
@ -89,6 +93,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error)
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsForRoomsStmt, err = db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
return
}
@ -153,6 +160,32 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
return result, rows.Err()
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs {
iRoomIDs[i] = roomIDs[i]
}
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, iRoomIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {

View file

@ -53,6 +53,7 @@ type FederationSenderJoinedHosts interface {
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
}
type FederationSenderRooms interface {