Add jetstream.WithJetStreamMessage to make ack/nak-ing less messy, use process context in consumers

This commit is contained in:
Neil Alexander 2022-01-05 13:45:27 +00:00
parent 1e92206fbc
commit 51de6612a6
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
11 changed files with 412 additions and 385 deletions

View file

@ -32,6 +32,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
topic string topic string
asDB storage.Database asDB storage.Database
@ -51,6 +52,7 @@ func NewOutputRoomEventConsumer(
workerStates []types.ApplicationServiceWorkerState, workerStates []types.ApplicationServiceWorkerState,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Global.JetStream.TopicFor(jetstream.OutputRoomEvent), topic: cfg.Global.JetStream.TopicFor(jetstream.OutputRoomEvent),
asDB: appserviceDB, asDB: appserviceDB,
@ -69,30 +71,30 @@ func (s *OutputRoomEventConsumer) Start() error {
// onMessage is called when the appservice component receives a new event from // onMessage is called when the appservice component receives a new event from
// the room server output log. // the room server output log.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
// Parse out the event JSON jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var output api.OutputEvent // Parse out the event JSON
if err := json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputEvent
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("roomserver output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
_ = msg.Ack() log.WithError(err).Errorf("roomserver output log: message parse failure")
return return true
} }
if output.Type != api.OutputTypeNewRoomEvent { if output.Type != api.OutputTypeNewRoomEvent {
_ = msg.Ack() return true
return }
}
events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event}
events = append(events, output.NewRoomEvent.AddStateEvents...) events = append(events, output.NewRoomEvent.AddStateEvents...)
// Send event to any relevant application services // Send event to any relevant application services
if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { if err := s.filterRoomserverEvents(context.TODO(), events); err != nil {
log.WithError(err).Errorf("roomserver output log: filter error") log.WithError(err).Errorf("roomserver output log: filter error")
return return true
} }
_ = msg.Ack() return true
})
} }
// filterRoomserverEvents takes in events and decides whether any of them need // filterRoomserverEvents takes in events and decides whether any of them need

View file

@ -32,6 +32,7 @@ import (
// OutputEDUConsumer consumes events that originate in EDU server. // OutputEDUConsumer consumes events that originate in EDU server.
type OutputEDUConsumer struct { type OutputEDUConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
@ -50,6 +51,7 @@ func NewOutputEDUConsumer(
store storage.Database, store storage.Database,
) *OutputEDUConsumer { ) *OutputEDUConsumer {
return &OutputEDUConsumer{ return &OutputEDUConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
queues: queues, queues: queues,
db: store, db: store,
@ -78,174 +80,173 @@ func (t *OutputEDUConsumer) Start() error {
// send-to-device events topic from the EDU server. // send-to-device events topic from the EDU server.
func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) {
// Extract the send-to-device event from msg. // Extract the send-to-device event from msg.
var ote api.OutputSendToDeviceEvent jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
if err := json.Unmarshal(msg.Data, &ote); err != nil { var ote api.OutputSendToDeviceEvent
log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") if err := json.Unmarshal(msg.Data, &ote); err != nil {
_ = msg.Ack() log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
return return true
} }
// only send send-to-device events which originated from us // only send send-to-device events which originated from us
_, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender)
if err != nil { if err != nil {
log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender")
_ = msg.Ack() return true
return }
} if originServerName != t.ServerName {
if originServerName != t.ServerName { log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere")
log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") return true
_ = msg.Ack() }
return
}
_, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID)
if err != nil { if err != nil {
log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination")
_ = msg.Ack() return true
return }
}
// Pack the EDU and marshal it // Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice, Type: gomatrixserverlib.MDirectToDevice,
Origin: string(t.ServerName), Origin: string(t.ServerName),
} }
tdm := gomatrixserverlib.ToDeviceMessage{ tdm := gomatrixserverlib.ToDeviceMessage{
Sender: ote.Sender, Sender: ote.Sender,
Type: ote.Type, Type: ote.Type,
MessageID: util.RandomString(32), MessageID: util.RandomString(32),
Messages: map[string]map[string]json.RawMessage{ Messages: map[string]map[string]json.RawMessage{
ote.UserID: { ote.UserID: {
ote.DeviceID: ote.Content, ote.DeviceID: ote.Content,
},
}, },
}, }
} if edu.Content, err = json.Marshal(tdm); err != nil {
if edu.Content, err = json.Marshal(tdm); err != nil { log.WithError(err).Error("failed to marshal EDU JSON")
log.WithError(err).Error("failed to marshal EDU JSON") return true
_ = msg.Ack() }
return
}
log.Infof("Sending send-to-device message into %q destination queue", destServerName) log.Infof("Sending send-to-device message into %q destination queue", destServerName)
if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
} return false
}
_ = msg.Ack() return true
})
} }
// onTypingEvent is called in response to a message received on the typing // onTypingEvent is called in response to a message received on the typing
// events topic from the EDU server. // events topic from the EDU server.
func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) {
// Extract the typing event from msg. jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var ote api.OutputTypingEvent // Extract the typing event from msg.
if err := json.Unmarshal(msg.Data, &ote); err != nil { var ote api.OutputTypingEvent
// Skip this msg but continue processing messages. if err := json.Unmarshal(msg.Data, &ote); err != nil {
log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") // Skip this msg but continue processing messages.
_ = msg.Ack() log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)")
return _ = msg.Ack()
} return true
}
// only send typing events which originated from us // only send typing events which originated from us
_, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID)
if err != nil { if err != nil {
log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender")
_ = msg.Ack() _ = msg.Ack()
return return true
} }
if typingServerName != t.ServerName { if typingServerName != t.ServerName {
return return true
} }
joined, err := t.db.GetJoinedHosts(context.TODO(), ote.Event.RoomID) joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID)
if err != nil { if err != nil {
log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room")
return return false
} }
names := make([]gomatrixserverlib.ServerName, len(joined)) names := make([]gomatrixserverlib.ServerName, len(joined))
for i := range joined { for i := range joined {
names[i] = joined[i].ServerName names[i] = joined[i].ServerName
} }
edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} edu := &gomatrixserverlib.EDU{Type: ote.Event.Type}
if edu.Content, err = json.Marshal(map[string]interface{}{ if edu.Content, err = json.Marshal(map[string]interface{}{
"room_id": ote.Event.RoomID, "room_id": ote.Event.RoomID,
"user_id": ote.Event.UserID, "user_id": ote.Event.UserID,
"typing": ote.Event.Typing, "typing": ote.Event.Typing,
}); err != nil { }); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON") log.WithError(err).Error("failed to marshal EDU JSON")
_ = msg.Ack() return true
return }
}
if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
} return false
}
_ = msg.Ack() return true
})
} }
// onReceiptEvent is called in response to a message received on the receipt // onReceiptEvent is called in response to a message received on the receipt
// events topic from the EDU server. // events topic from the EDU server.
func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) {
// Extract the typing event from msg. jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var receipt api.OutputReceiptEvent // Extract the typing event from msg.
if err := json.Unmarshal(msg.Data, &receipt); err != nil { var receipt api.OutputReceiptEvent
// Skip this msg but continue processing messages. if err := json.Unmarshal(msg.Data, &receipt); err != nil {
log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") // Skip this msg but continue processing messages.
_ = msg.Ack() log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)")
return return true
} }
// only send receipt events which originated from us // only send receipt events which originated from us
_, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID)
if err != nil { if err != nil {
log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender")
_ = msg.Ack() return true
return }
} if receiptServerName != t.ServerName {
if receiptServerName != t.ServerName { return true
_ = msg.Ack() }
return // don't log, very spammy as it logs for each remote receipt
}
joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID) joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID)
if err != nil { if err != nil {
log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room")
return return false
} }
names := make([]gomatrixserverlib.ServerName, len(joined)) names := make([]gomatrixserverlib.ServerName, len(joined))
for i := range joined { for i := range joined {
names[i] = joined[i].ServerName names[i] = joined[i].ServerName
} }
content := map[string]api.FederationReceiptMRead{} content := map[string]api.FederationReceiptMRead{}
content[receipt.RoomID] = api.FederationReceiptMRead{ content[receipt.RoomID] = api.FederationReceiptMRead{
User: map[string]api.FederationReceiptData{ User: map[string]api.FederationReceiptData{
receipt.UserID: { receipt.UserID: {
Data: api.ReceiptTS{ Data: api.ReceiptTS{
TS: receipt.Timestamp, TS: receipt.Timestamp,
},
EventIDs: []string{receipt.EventID},
}, },
EventIDs: []string{receipt.EventID},
}, },
}, }
}
edu := &gomatrixserverlib.EDU{ edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MReceipt, Type: gomatrixserverlib.MReceipt,
Origin: string(t.ServerName), Origin: string(t.ServerName),
} }
if edu.Content, err = json.Marshal(content); err != nil { if edu.Content, err = json.Marshal(content); err != nil {
log.WithError(err).Error("failed to marshal EDU JSON") log.WithError(err).Error("failed to marshal EDU JSON")
_ = msg.Ack() return true
return }
}
if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil {
log.WithError(err).Error("failed to send EDU") log.WithError(err).Error("failed to send EDU")
} return false
}
_ = msg.Ack() return true
})
} }

View file

@ -35,6 +35,7 @@ import (
// KeyChangeConsumer consumes events that originate in key server. // KeyChangeConsumer consumes events that originate in key server.
type KeyChangeConsumer struct { type KeyChangeConsumer struct {
ctx context.Context
consumer *internal.ContinualConsumer consumer *internal.ContinualConsumer
db storage.Database db storage.Database
queues *queue.OutgoingQueues queues *queue.OutgoingQueues
@ -52,6 +53,7 @@ func NewKeyChangeConsumer(
rsAPI roomserverAPI.RoomserverInternalAPI, rsAPI roomserverAPI.RoomserverInternalAPI,
) *KeyChangeConsumer { ) *KeyChangeConsumer {
c := &KeyChangeConsumer{ c := &KeyChangeConsumer{
ctx: process.Context(),
consumer: &internal.ContinualConsumer{ consumer: &internal.ContinualConsumer{
Process: process, Process: process,
ComponentName: "federationsender/keychange", ComponentName: "federationsender/keychange",
@ -117,7 +119,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
} }
var queryRes roomserverAPI.QueryRoomsForUserResponse var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
UserID: m.UserID, UserID: m.UserID,
WantMembership: "join", WantMembership: "join",
}, &queryRes) }, &queryRes)
@ -126,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs)
if err != nil { if err != nil {
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
return nil return nil
@ -169,7 +171,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
logger := logrus.WithField("user_id", output.UserID) logger := logrus.WithField("user_id", output.UserID)
var queryRes roomserverAPI.QueryRoomsForUserResponse var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{
UserID: output.UserID, UserID: output.UserID,
WantMembership: "join", WantMembership: "join",
}, &queryRes) }, &queryRes)
@ -178,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
return nil return nil
} }
// send this key change to all servers who share rooms with this user. // send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs)
if err != nil { if err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
return nil return nil

View file

@ -33,6 +33,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context
cfg *config.FederationAPI cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
@ -51,6 +52,7 @@ func NewOutputRoomEventConsumer(
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
db: store, db: store,
@ -71,65 +73,61 @@ func (s *OutputRoomEventConsumer) Start() error {
// because updates it will likely fail with a types.EventIDMismatchError when it // because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas. // realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
// Parse out the event JSON jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var output api.OutputEvent // Parse out the event JSON
if err := json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputEvent
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("roomserver output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
_ = msg.Ack() log.WithError(err).Errorf("roomserver output log: message parse failure")
return return true
}
switch output.Type {
case api.OutputTypeNewRoomEvent:
ev := output.NewRoomEvent.Event
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
log.WithError(err).Errorf("roomserver output log: purge room state failure")
return
}
} }
if err := s.processMessage(*output.NewRoomEvent); err != nil { switch output.Type {
switch err.(type) { case api.OutputTypeNewRoomEvent:
case *queue.ErrorFederationDisabled: ev := output.NewRoomEvent.Event
log.WithField("error", output.Type).Info(
err.Error(), if output.NewRoomEvent.RewritesState {
) if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil {
_ = msg.Ack() log.WithError(err).Errorf("roomserver output log: purge room state failure")
default: return false
// panic rather than continue with an inconsistent database }
}
if err := s.processMessage(*output.NewRoomEvent); err != nil {
switch err.(type) {
case *queue.ErrorFederationDisabled:
log.WithField("error", output.Type).Info(
err.Error(),
)
default:
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"event": string(ev.JSON()),
"add": output.NewRoomEvent.AddsStateEventIDs,
"del": output.NewRoomEvent.RemovesStateEventIDs,
log.ErrorKey: err,
}).Panicf("roomserver output log: write room event failure")
}
}
case api.OutputTypeNewInboundPeek:
if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": ev.EventID(), "event": output.NewInboundPeek,
"event": string(ev.JSON()),
"add": output.NewRoomEvent.AddsStateEventIDs,
"del": output.NewRoomEvent.RemovesStateEventIDs,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write room event failure") }).Panicf("roomserver output log: remote peek event failure")
return false
} }
return
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
} }
_ = msg.Ack() return true
})
case api.OutputTypeNewInboundPeek:
if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
log.WithFields(log.Fields{
"event": output.NewInboundPeek,
log.ErrorKey: err,
}).Panicf("roomserver output log: remote peek event failure")
return
}
_ = msg.Ack()
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
_ = msg.Ack()
return
}
} }
// processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)
@ -146,7 +144,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee
// //
// This is making the tests flakey. // This is making the tests flakey.
return s.db.AddInboundPeek(context.TODO(), orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval) return s.db.AddInboundPeek(s.ctx, orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval)
} }
// processMessage updates the list of currently joined hosts in the room // processMessage updates the list of currently joined hosts in the room
@ -162,7 +160,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
// TODO(#290): handle EventIDMismatchError and recover the current state by // TODO(#290): handle EventIDMismatchError and recover the current state by
// talking to the roomserver // talking to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom( oldJoinedHosts, err := s.db.UpdateRoom(
context.TODO(), s.ctx,
ore.Event.RoomID(), ore.Event.RoomID(),
ore.LastSentEventID, ore.LastSentEventID,
ore.Event.EventID(), ore.Event.EventID(),
@ -255,7 +253,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
} }
// handle peeking hosts // handle peeking hosts
inboundPeeks, err := s.db.GetInboundPeeks(context.TODO(), ore.Event.Event.RoomID()) inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.Event.RoomID())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -373,7 +371,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// from the roomserver using the query API. // from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing} eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
var eventResp api.QueryEventsByIDResponse var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil {
return nil, err return nil, err
} }

View file

@ -0,0 +1,11 @@
package jetstream
import "github.com/nats-io/nats.go"
func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) {
if f(msg) {
_ = msg.Ack()
} else {
_ = msg.Nak()
}
}

View file

@ -32,6 +32,7 @@ import (
// OutputClientDataConsumer consumes events that originated in the client API server. // OutputClientDataConsumer consumes events that originated in the client API server.
type OutputClientDataConsumer struct { type OutputClientDataConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
topic string topic string
db storage.Database db storage.Database
@ -49,6 +50,7 @@ func NewOutputClientDataConsumer(
stream types.StreamProvider, stream types.StreamProvider,
) *OutputClientDataConsumer { ) *OutputClientDataConsumer {
return &OutputClientDataConsumer{ return &OutputClientDataConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData),
db: store, db: store,
@ -67,36 +69,37 @@ func (s *OutputClientDataConsumer) Start() error {
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) {
// Parse out the event JSON jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
userID := msg.Header.Get(jetstream.UserID) // Parse out the event JSON
var output eventutil.AccountData userID := msg.Header.Get(jetstream.UserID)
if err := json.Unmarshal(msg.Data, &output); err != nil { var output eventutil.AccountData
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("client API server output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
sentry.CaptureException(err) log.WithError(err).Errorf("client API server output log: message parse failure")
_ = msg.Ack() sentry.CaptureException(err)
return return true
} }
log.WithFields(log.Fields{
"type": output.Type,
"room_id": output.RoomID,
}).Info("received data from client API server")
streamPos, err := s.db.UpsertAccountData(
context.TODO(), userID, output.RoomID, output.Type,
)
if err != nil {
sentry.CaptureException(err)
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"type": output.Type, "type": output.Type,
"room_id": output.RoomID, "room_id": output.RoomID,
log.ErrorKey: err, }).Info("received data from client API server")
}).Panicf("could not save account data")
}
s.stream.Advance(streamPos) streamPos, err := s.db.UpsertAccountData(
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) s.ctx, userID, output.RoomID, output.Type,
)
if err != nil {
sentry.CaptureException(err)
log.WithFields(log.Fields{
"type": output.Type,
"room_id": output.RoomID,
log.ErrorKey: err,
}).Panicf("could not save account data")
}
_ = msg.Ack() s.stream.Advance(streamPos)
s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos})
return true
})
} }

View file

@ -32,6 +32,7 @@ import (
// OutputReceiptEventConsumer consumes events that originated in the EDU server. // OutputReceiptEventConsumer consumes events that originated in the EDU server.
type OutputReceiptEventConsumer struct { type OutputReceiptEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
topic string topic string
db storage.Database db storage.Database
@ -50,6 +51,7 @@ func NewOutputReceiptEventConsumer(
stream types.StreamProvider, stream types.StreamProvider,
) *OutputReceiptEventConsumer { ) *OutputReceiptEventConsumer {
return &OutputReceiptEventConsumer{ return &OutputReceiptEventConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent),
db: store, db: store,
@ -65,30 +67,31 @@ func (s *OutputReceiptEventConsumer) Start() error {
} }
func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) {
var output api.OutputReceiptEvent jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
if err := json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputReceiptEvent
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("EDU server output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
sentry.CaptureException(err) log.WithError(err).Errorf("EDU server output log: message parse failure")
_ = msg.Ack() sentry.CaptureException(err)
return return true
} }
streamPos, err := s.db.StoreReceipt( streamPos, err := s.db.StoreReceipt(
context.TODO(), s.ctx,
output.RoomID, output.RoomID,
output.Type, output.Type,
output.UserID, output.UserID,
output.EventID, output.EventID,
output.Timestamp, output.Timestamp,
) )
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
return return true
} }
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos})
_ = msg.Ack() return true
})
} }

View file

@ -34,6 +34,7 @@ import (
// OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server.
type OutputSendToDeviceEventConsumer struct { type OutputSendToDeviceEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
topic string topic string
db storage.Database db storage.Database
@ -53,6 +54,7 @@ func NewOutputSendToDeviceEventConsumer(
stream types.StreamProvider, stream types.StreamProvider,
) *OutputSendToDeviceEventConsumer { ) *OutputSendToDeviceEventConsumer {
return &OutputSendToDeviceEventConsumer{ return &OutputSendToDeviceEventConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputSendToDeviceEvent), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputSendToDeviceEvent),
db: store, db: store,
@ -69,48 +71,47 @@ func (s *OutputSendToDeviceEventConsumer) Start() error {
} }
func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) {
var output api.OutputSendToDeviceEvent jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
if err := json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputSendToDeviceEvent
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("EDU server output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
sentry.CaptureException(err) log.WithError(err).Errorf("EDU server output log: message parse failure")
_ = msg.Ack() sentry.CaptureException(err)
return return true
} }
_, domain, err := gomatrixserverlib.SplitID('@', output.UserID) _, domain, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
_ = msg.Ack() return true
return }
} if domain != s.serverName {
if domain != s.serverName { return true
_ = msg.Ack() }
return
}
util.GetLogger(context.TODO()).WithFields(log.Fields{ util.GetLogger(context.TODO()).WithFields(log.Fields{
"sender": output.Sender, "sender": output.Sender,
"user_id": output.UserID, "user_id": output.UserID,
"device_id": output.DeviceID, "device_id": output.DeviceID,
"event_type": output.Type, "event_type": output.Type,
}).Info("sync API received send-to-device event from EDU server") }).Info("sync API received send-to-device event from EDU server")
streamPos, err := s.db.StoreNewSendForDeviceMessage( streamPos, err := s.db.StoreNewSendForDeviceMessage(
context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent, s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent,
) )
if err != nil { if err != nil {
sentry.CaptureException(err) sentry.CaptureException(err)
log.WithError(err).Errorf("failed to store send-to-device message") log.WithError(err).Errorf("failed to store send-to-device message")
return return false
} }
s.stream.Advance(streamPos) s.stream.Advance(streamPos)
s.notifier.OnNewSendToDevice( s.notifier.OnNewSendToDevice(
output.UserID, output.UserID,
[]string{output.DeviceID}, []string{output.DeviceID},
types.StreamingToken{SendToDevicePosition: streamPos}, types.StreamingToken{SendToDevicePosition: streamPos},
) )
_ = msg.Ack() return true
})
} }

View file

@ -15,6 +15,7 @@
package consumers package consumers
import ( import (
"context"
"encoding/json" "encoding/json"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
@ -32,6 +33,7 @@ import (
// OutputTypingEventConsumer consumes events that originated in the EDU server. // OutputTypingEventConsumer consumes events that originated in the EDU server.
type OutputTypingEventConsumer struct { type OutputTypingEventConsumer struct {
ctx context.Context
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
topic string topic string
eduCache *cache.EDUCache eduCache *cache.EDUCache
@ -51,6 +53,7 @@ func NewOutputTypingEventConsumer(
stream types.StreamProvider, stream types.StreamProvider,
) *OutputTypingEventConsumer { ) *OutputTypingEventConsumer {
return &OutputTypingEventConsumer{ return &OutputTypingEventConsumer{
ctx: process.Context(),
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputTypingEvent), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputTypingEvent),
eduCache: eduCache, eduCache: eduCache,
@ -66,35 +69,36 @@ func (s *OutputTypingEventConsumer) Start() error {
} }
func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) {
var output api.OutputTypingEvent jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
if err := json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputTypingEvent
// If the message was invalid, log it and move on to the next message in the stream if err := json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("EDU server output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
sentry.CaptureException(err) log.WithError(err).Errorf("EDU server output log: message parse failure")
_ = msg.Ack() sentry.CaptureException(err)
return return true
} }
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"room_id": output.Event.RoomID, "room_id": output.Event.RoomID,
"user_id": output.Event.UserID, "user_id": output.Event.UserID,
"typing": output.Event.Typing, "typing": output.Event.Typing,
}).Debug("received data from EDU server") }).Debug("received data from EDU server")
var typingPos types.StreamPosition var typingPos types.StreamPosition
typingEvent := output.Event typingEvent := output.Event
if typingEvent.Typing { if typingEvent.Typing {
typingPos = types.StreamPosition( typingPos = types.StreamPosition(
s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime),
) )
} else { } else {
typingPos = types.StreamPosition( typingPos = types.StreamPosition(
s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID),
) )
} }
s.stream.Advance(typingPos) s.stream.Advance(typingPos)
s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos})
_ = msg.Ack() return true
})
} }

View file

@ -34,6 +34,7 @@ import (
// OutputKeyChangeEventConsumer consumes events that originated in the key server. // OutputKeyChangeEventConsumer consumes events that originated in the key server.
type OutputKeyChangeEventConsumer struct { type OutputKeyChangeEventConsumer struct {
ctx context.Context
keyChangeConsumer *internal.ContinualConsumer keyChangeConsumer *internal.ContinualConsumer
db storage.Database db storage.Database
notifier *notifier.Notifier notifier *notifier.Notifier
@ -68,6 +69,7 @@ func NewOutputKeyChangeEventConsumer(
} }
s := &OutputKeyChangeEventConsumer{ s := &OutputKeyChangeEventConsumer{
ctx: process.Context(),
keyChangeConsumer: &consumer, keyChangeConsumer: &consumer,
db: store, db: store,
serverName: serverName, serverName: serverName,
@ -131,7 +133,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o
output := m.DeviceKeys output := m.DeviceKeys
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(context.Background(), &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {
@ -158,7 +160,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage
output := m.CrossSigningKeyUpdate output := m.CrossSigningKeyUpdate
// work out who we need to notify about the new key // work out who we need to notify about the new key
var queryRes roomserverAPI.QuerySharedUsersResponse var queryRes roomserverAPI.QuerySharedUsersResponse
err := s.rsAPI.QuerySharedUsers(context.Background(), &roomserverAPI.QuerySharedUsersRequest{ err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{
UserID: output.UserID, UserID: output.UserID,
}, &queryRes) }, &queryRes)
if err != nil { if err != nil {

View file

@ -34,6 +34,7 @@ import (
// OutputRoomEventConsumer consumes events that originated in the room server. // OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct { type OutputRoomEventConsumer struct {
ctx context.Context
cfg *config.SyncAPI cfg *config.SyncAPI
rsAPI api.RoomserverInternalAPI rsAPI api.RoomserverInternalAPI
jetstream nats.JetStreamContext jetstream nats.JetStreamContext
@ -56,6 +57,7 @@ func NewOutputRoomEventConsumer(
rsAPI api.RoomserverInternalAPI, rsAPI api.RoomserverInternalAPI,
) *OutputRoomEventConsumer { ) *OutputRoomEventConsumer {
return &OutputRoomEventConsumer{ return &OutputRoomEventConsumer{
ctx: process.Context(),
cfg: cfg, cfg: cfg,
jetstream: js, jetstream: js,
topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent),
@ -77,55 +79,53 @@ func (s *OutputRoomEventConsumer) Start() error {
// It is not safe for this function to be called from multiple goroutines, or else the // It is not safe for this function to be called from multiple goroutines, or else the
// sync stream position may race and be incorrectly calculated. // sync stream position may race and be incorrectly calculated.
func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) {
// Parse out the event JSON jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool {
var err error // Parse out the event JSON
var output api.OutputEvent var err error
if err = json.Unmarshal(msg.Data, &output); err != nil { var output api.OutputEvent
// If the message was invalid, log it and move on to the next message in the stream if err = json.Unmarshal(msg.Data, &output); err != nil {
log.WithError(err).Errorf("roomserver output log: message parse failure") // If the message was invalid, log it and move on to the next message in the stream
_ = msg.Ack() log.WithError(err).Errorf("roomserver output log: message parse failure")
return return true
}
switch output.Type {
case api.OutputTypeNewRoomEvent:
// Ignore redaction events. We will add them to the database when they are
// validated (when we receive OutputTypeRedactedEvent)
event := output.NewRoomEvent.Event
if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil {
// in the special case where the event redacts itself, just pass the message through because
// we will never see the other part of the pair
if event.Redacts() != event.EventID() {
_ = msg.Ack()
return
}
} }
err = s.onNewRoomEvent(context.TODO(), *output.NewRoomEvent)
case api.OutputTypeOldRoomEvent:
err = s.onOldRoomEvent(context.TODO(), *output.OldRoomEvent)
case api.OutputTypeNewInviteEvent:
s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent)
case api.OutputTypeNewPeek:
s.onNewPeek(context.TODO(), *output.NewPeek)
case api.OutputTypeRetirePeek:
s.onRetirePeek(context.TODO(), *output.RetirePeek)
case api.OutputTypeRedactedEvent:
err = s.onRedactEvent(context.TODO(), *output.RedactedEvent)
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
_ = msg.Ack()
}
if err != nil {
log.WithError(err).Error("roomserver output log: failed to process event")
_ = msg.Nak()
return
}
_ = msg.Ack() switch output.Type {
case api.OutputTypeNewRoomEvent:
// Ignore redaction events. We will add them to the database when they are
// validated (when we receive OutputTypeRedactedEvent)
event := output.NewRoomEvent.Event
if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil {
// in the special case where the event redacts itself, just pass the message through because
// we will never see the other part of the pair
if event.Redacts() != event.EventID() {
return true
}
}
err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent)
case api.OutputTypeOldRoomEvent:
err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent)
case api.OutputTypeNewInviteEvent:
s.onNewInviteEvent(s.ctx, *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent)
case api.OutputTypeNewPeek:
s.onNewPeek(s.ctx, *output.NewPeek)
case api.OutputTypeRetirePeek:
s.onRetirePeek(s.ctx, *output.RetirePeek)
case api.OutputTypeRedactedEvent:
err = s.onRedactEvent(s.ctx, *output.RedactedEvent)
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
}
if err != nil {
log.WithError(err).Error("roomserver output log: failed to process event")
return false
}
return true
})
} }
func (s *OutputRoomEventConsumer) onRedactEvent( func (s *OutputRoomEventConsumer) onRedactEvent(