From 51de6612a66e09c76f32d2c857263754c039db23 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Wed, 5 Jan 2022 13:45:27 +0000 Subject: [PATCH] Add `jetstream.WithJetStreamMessage` to make ack/nak-ing less messy, use process context in consumers --- appservice/consumers/roomserver.go | 42 +-- federationapi/consumers/eduserver.go | 279 ++++++++++---------- federationapi/consumers/keychange.go | 10 +- federationapi/consumers/roomserver.go | 110 ++++---- setup/jetstream/helpers.go | 11 + syncapi/consumers/clientapi.go | 59 +++-- syncapi/consumers/eduserver_receipts.go | 49 ++-- syncapi/consumers/eduserver_sendtodevice.go | 79 +++--- syncapi/consumers/eduserver_typing.go | 58 ++-- syncapi/consumers/keychange.go | 6 +- syncapi/consumers/roomserver.go | 94 +++---- 11 files changed, 412 insertions(+), 385 deletions(-) create mode 100644 setup/jetstream/helpers.go diff --git a/appservice/consumers/roomserver.go b/appservice/consumers/roomserver.go index ddcc478c..139b5724 100644 --- a/appservice/consumers/roomserver.go +++ b/appservice/consumers/roomserver.go @@ -32,6 +32,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { + ctx context.Context jetstream nats.JetStreamContext topic string asDB storage.Database @@ -51,6 +52,7 @@ func NewOutputRoomEventConsumer( workerStates []types.ApplicationServiceWorkerState, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ + ctx: process.Context(), jetstream: js, topic: cfg.Global.JetStream.TopicFor(jetstream.OutputRoomEvent), asDB: appserviceDB, @@ -69,30 +71,30 @@ func (s *OutputRoomEventConsumer) Start() error { // onMessage is called when the appservice component receives a new event from // the room server output log. func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true + } - if output.Type != api.OutputTypeNewRoomEvent { - _ = msg.Ack() - return - } + if output.Type != api.OutputTypeNewRoomEvent { + return true + } - events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} - events = append(events, output.NewRoomEvent.AddStateEvents...) + events := []*gomatrixserverlib.HeaderedEvent{output.NewRoomEvent.Event} + events = append(events, output.NewRoomEvent.AddStateEvents...) - // Send event to any relevant application services - if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { - log.WithError(err).Errorf("roomserver output log: filter error") - return - } + // Send event to any relevant application services + if err := s.filterRoomserverEvents(context.TODO(), events); err != nil { + log.WithError(err).Errorf("roomserver output log: filter error") + return true + } - _ = msg.Ack() + return true + }) } // filterRoomserverEvents takes in events and decides whether any of them need diff --git a/federationapi/consumers/eduserver.go b/federationapi/consumers/eduserver.go index 3cef5c1d..9e52acef 100644 --- a/federationapi/consumers/eduserver.go +++ b/federationapi/consumers/eduserver.go @@ -32,6 +32,7 @@ import ( // OutputEDUConsumer consumes events that originate in EDU server. type OutputEDUConsumer struct { + ctx context.Context jetstream nats.JetStreamContext db storage.Database queues *queue.OutgoingQueues @@ -50,6 +51,7 @@ func NewOutputEDUConsumer( store storage.Database, ) *OutputEDUConsumer { return &OutputEDUConsumer{ + ctx: process.Context(), jetstream: js, queues: queues, db: store, @@ -78,174 +80,173 @@ func (t *OutputEDUConsumer) Start() error { // send-to-device events topic from the EDU server. func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *nats.Msg) { // Extract the send-to-device event from msg. - var ote api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + var ote api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)") + return true + } - // only send send-to-device events which originated from us - _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) - if err != nil { - log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") - _ = msg.Ack() - return - } - if originServerName != t.ServerName { - log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") - _ = msg.Ack() - return - } + // only send send-to-device events which originated from us + _, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender) + if err != nil { + log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender") + return true + } + if originServerName != t.ServerName { + log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere") + return true + } - _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") - _ = msg.Ack() - return - } + _, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination") + return true + } - // Pack the EDU and marshal it - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MDirectToDevice, - Origin: string(t.ServerName), - } - tdm := gomatrixserverlib.ToDeviceMessage{ - Sender: ote.Sender, - Type: ote.Type, - MessageID: util.RandomString(32), - Messages: map[string]map[string]json.RawMessage{ - ote.UserID: { - ote.DeviceID: ote.Content, + // Pack the EDU and marshal it + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MDirectToDevice, + Origin: string(t.ServerName), + } + tdm := gomatrixserverlib.ToDeviceMessage{ + Sender: ote.Sender, + Type: ote.Type, + MessageID: util.RandomString(32), + Messages: map[string]map[string]json.RawMessage{ + ote.UserID: { + ote.DeviceID: ote.Content, + }, }, - }, - } - if edu.Content, err = json.Marshal(tdm); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - _ = msg.Ack() - return - } + } + if edu.Content, err = json.Marshal(tdm); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } - log.Infof("Sending send-to-device message into %q destination queue", destServerName) - if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { - log.WithError(err).Error("failed to send EDU") - } + log.Infof("Sending send-to-device message into %q destination queue", destServerName) + if err := t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName}); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } - _ = msg.Ack() + return true + }) } // onTypingEvent is called in response to a message received on the typing // events topic from the EDU server. func (t *OutputEDUConsumer) onTypingEvent(msg *nats.Msg) { - // Extract the typing event from msg. - var ote api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &ote); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Extract the typing event from msg. + var ote api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &ote); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)") + _ = msg.Ack() + return true + } - // only send typing events which originated from us - _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) - if err != nil { - log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") - _ = msg.Ack() - return - } - if typingServerName != t.ServerName { - return - } + // only send typing events which originated from us + _, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID) + if err != nil { + log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender") + _ = msg.Ack() + return true + } + if typingServerName != t.ServerName { + return true + } - joined, err := t.db.GetJoinedHosts(context.TODO(), ote.Event.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") - return - } + joined, err := t.db.GetJoinedHosts(t.ctx, ote.Event.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", ote.Event.RoomID).Error("failed to get joined hosts for room") + return false + } - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } - edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} - if edu.Content, err = json.Marshal(map[string]interface{}{ - "room_id": ote.Event.RoomID, - "user_id": ote.Event.UserID, - "typing": ote.Event.Typing, - }); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - _ = msg.Ack() - return - } + edu := &gomatrixserverlib.EDU{Type: ote.Event.Type} + if edu.Content, err = json.Marshal(map[string]interface{}{ + "room_id": ote.Event.RoomID, + "user_id": ote.Event.UserID, + "typing": ote.Event.Typing, + }); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - } + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } - _ = msg.Ack() + return true + }) } // onReceiptEvent is called in response to a message received on the receipt // events topic from the EDU server. func (t *OutputEDUConsumer) onReceiptEvent(msg *nats.Msg) { - // Extract the typing event from msg. - var receipt api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &receipt); err != nil { - // Skip this msg but continue processing messages. - log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Extract the typing event from msg. + var receipt api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &receipt); err != nil { + // Skip this msg but continue processing messages. + log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)") + return true + } - // only send receipt events which originated from us - _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) - if err != nil { - log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") - _ = msg.Ack() - return - } - if receiptServerName != t.ServerName { - _ = msg.Ack() - return // don't log, very spammy as it logs for each remote receipt - } + // only send receipt events which originated from us + _, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID) + if err != nil { + log.WithError(err).WithField("user_id", receipt.UserID).Error("failed to extract domain from receipt sender") + return true + } + if receiptServerName != t.ServerName { + return true + } - joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID) - if err != nil { - log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") - return - } + joined, err := t.db.GetJoinedHosts(t.ctx, receipt.RoomID) + if err != nil { + log.WithError(err).WithField("room_id", receipt.RoomID).Error("failed to get joined hosts for room") + return false + } - names := make([]gomatrixserverlib.ServerName, len(joined)) - for i := range joined { - names[i] = joined[i].ServerName - } + names := make([]gomatrixserverlib.ServerName, len(joined)) + for i := range joined { + names[i] = joined[i].ServerName + } - content := map[string]api.FederationReceiptMRead{} - content[receipt.RoomID] = api.FederationReceiptMRead{ - User: map[string]api.FederationReceiptData{ - receipt.UserID: { - Data: api.ReceiptTS{ - TS: receipt.Timestamp, + content := map[string]api.FederationReceiptMRead{} + content[receipt.RoomID] = api.FederationReceiptMRead{ + User: map[string]api.FederationReceiptData{ + receipt.UserID: { + Data: api.ReceiptTS{ + TS: receipt.Timestamp, + }, + EventIDs: []string{receipt.EventID}, }, - EventIDs: []string{receipt.EventID}, }, - }, - } + } - edu := &gomatrixserverlib.EDU{ - Type: gomatrixserverlib.MReceipt, - Origin: string(t.ServerName), - } - if edu.Content, err = json.Marshal(content); err != nil { - log.WithError(err).Error("failed to marshal EDU JSON") - _ = msg.Ack() - return - } + edu := &gomatrixserverlib.EDU{ + Type: gomatrixserverlib.MReceipt, + Origin: string(t.ServerName), + } + if edu.Content, err = json.Marshal(content); err != nil { + log.WithError(err).Error("failed to marshal EDU JSON") + return true + } - if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { - log.WithError(err).Error("failed to send EDU") - } + if err := t.queues.SendEDU(edu, t.ServerName, names); err != nil { + log.WithError(err).Error("failed to send EDU") + return false + } - _ = msg.Ack() + return true + }) } diff --git a/federationapi/consumers/keychange.go b/federationapi/consumers/keychange.go index 882c8ada..adba5800 100644 --- a/federationapi/consumers/keychange.go +++ b/federationapi/consumers/keychange.go @@ -35,6 +35,7 @@ import ( // KeyChangeConsumer consumes events that originate in key server. type KeyChangeConsumer struct { + ctx context.Context consumer *internal.ContinualConsumer db storage.Database queues *queue.OutgoingQueues @@ -52,6 +53,7 @@ func NewKeyChangeConsumer( rsAPI roomserverAPI.RoomserverInternalAPI, ) *KeyChangeConsumer { c := &KeyChangeConsumer{ + ctx: process.Context(), consumer: &internal.ContinualConsumer{ Process: process, ComponentName: "federationsender/keychange", @@ -117,7 +119,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { } var queryRes roomserverAPI.QueryRoomsForUserResponse - err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ + err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ UserID: m.UserID, WantMembership: "join", }, &queryRes) @@ -126,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) if err != nil { logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in") return nil @@ -169,7 +171,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { logger := logrus.WithField("user_id", output.UserID) var queryRes roomserverAPI.QueryRoomsForUserResponse - err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{ + err = t.rsAPI.QueryRoomsForUser(t.ctx, &roomserverAPI.QueryRoomsForUserRequest{ UserID: output.UserID, WantMembership: "join", }, &queryRes) @@ -178,7 +180,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error { return nil } // send this key change to all servers who share rooms with this user. - destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs) + destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs) if err != nil { logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in") return nil diff --git a/federationapi/consumers/roomserver.go b/federationapi/consumers/roomserver.go index 0bd61fb4..12410bb7 100644 --- a/federationapi/consumers/roomserver.go +++ b/federationapi/consumers/roomserver.go @@ -33,6 +33,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { + ctx context.Context cfg *config.FederationAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext @@ -51,6 +52,7 @@ func NewOutputRoomEventConsumer( rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ + ctx: process.Context(), cfg: cfg, jetstream: js, db: store, @@ -71,65 +73,61 @@ func (s *OutputRoomEventConsumer) Start() error { // because updates it will likely fail with a types.EventIDMismatchError when it // realises that it cannot update the room state using the deltas. func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - // Parse out the event JSON - var output api.OutputEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - _ = msg.Ack() - return - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - ev := output.NewRoomEvent.Event - - if output.NewRoomEvent.RewritesState { - if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil { - log.WithError(err).Errorf("roomserver output log: purge room state failure") - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Parse out the event JSON + var output api.OutputEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true } - if err := s.processMessage(*output.NewRoomEvent); err != nil { - switch err.(type) { - case *queue.ErrorFederationDisabled: - log.WithField("error", output.Type).Info( - err.Error(), - ) - _ = msg.Ack() - default: - // panic rather than continue with an inconsistent database + switch output.Type { + case api.OutputTypeNewRoomEvent: + ev := output.NewRoomEvent.Event + + if output.NewRoomEvent.RewritesState { + if err := s.db.PurgeRoomState(s.ctx, ev.RoomID()); err != nil { + log.WithError(err).Errorf("roomserver output log: purge room state failure") + return false + } + } + + if err := s.processMessage(*output.NewRoomEvent); err != nil { + switch err.(type) { + case *queue.ErrorFederationDisabled: + log.WithField("error", output.Type).Info( + err.Error(), + ) + default: + // panic rather than continue with an inconsistent database + log.WithFields(log.Fields{ + "event_id": ev.EventID(), + "event": string(ev.JSON()), + "add": output.NewRoomEvent.AddsStateEventIDs, + "del": output.NewRoomEvent.RemovesStateEventIDs, + log.ErrorKey: err, + }).Panicf("roomserver output log: write room event failure") + } + } + + case api.OutputTypeNewInboundPeek: + if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { log.WithFields(log.Fields{ - "event_id": ev.EventID(), - "event": string(ev.JSON()), - "add": output.NewRoomEvent.AddsStateEventIDs, - "del": output.NewRoomEvent.RemovesStateEventIDs, + "event": output.NewInboundPeek, log.ErrorKey: err, - }).Panicf("roomserver output log: write room event failure") + }).Panicf("roomserver output log: remote peek event failure") + return false } - return + + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) } - _ = msg.Ack() - - case api.OutputTypeNewInboundPeek: - if err := s.processInboundPeek(*output.NewInboundPeek); err != nil { - log.WithFields(log.Fields{ - "event": output.NewInboundPeek, - log.ErrorKey: err, - }).Panicf("roomserver output log: remote peek event failure") - return - } - _ = msg.Ack() - - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - _ = msg.Ack() - return - } + return true + }) } // processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any) @@ -146,7 +144,7 @@ func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPee // // This is making the tests flakey. - return s.db.AddInboundPeek(context.TODO(), orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval) + return s.db.AddInboundPeek(s.ctx, orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval) } // processMessage updates the list of currently joined hosts in the room @@ -162,7 +160,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err // TODO(#290): handle EventIDMismatchError and recover the current state by // talking to the roomserver oldJoinedHosts, err := s.db.UpdateRoom( - context.TODO(), + s.ctx, ore.Event.RoomID(), ore.LastSentEventID, ore.Event.EventID(), @@ -255,7 +253,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent( } // handle peeking hosts - inboundPeeks, err := s.db.GetInboundPeeks(context.TODO(), ore.Event.Event.RoomID()) + inboundPeeks, err := s.db.GetInboundPeeks(s.ctx, ore.Event.Event.RoomID()) if err != nil { return nil, err } @@ -373,7 +371,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents( // from the roomserver using the query API. eventReq := api.QueryEventsByIDRequest{EventIDs: missing} var eventResp api.QueryEventsByIDResponse - if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil { + if err := s.rsAPI.QueryEventsByID(s.ctx, &eventReq, &eventResp); err != nil { return nil, err } diff --git a/setup/jetstream/helpers.go b/setup/jetstream/helpers.go new file mode 100644 index 00000000..2d563226 --- /dev/null +++ b/setup/jetstream/helpers.go @@ -0,0 +1,11 @@ +package jetstream + +import "github.com/nats-io/nats.go" + +func WithJetStreamMessage(msg *nats.Msg, f func(msg *nats.Msg) bool) { + if f(msg) { + _ = msg.Ack() + } else { + _ = msg.Nak() + } +} diff --git a/syncapi/consumers/clientapi.go b/syncapi/consumers/clientapi.go index 06756fb5..85710cdd 100644 --- a/syncapi/consumers/clientapi.go +++ b/syncapi/consumers/clientapi.go @@ -32,6 +32,7 @@ import ( // OutputClientDataConsumer consumes events that originated in the client API server. type OutputClientDataConsumer struct { + ctx context.Context jetstream nats.JetStreamContext topic string db storage.Database @@ -49,6 +50,7 @@ func NewOutputClientDataConsumer( stream types.StreamProvider, ) *OutputClientDataConsumer { return &OutputClientDataConsumer{ + ctx: process.Context(), jetstream: js, topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputClientData), db: store, @@ -67,36 +69,37 @@ func (s *OutputClientDataConsumer) Start() error { // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. func (s *OutputClientDataConsumer) onMessage(msg *nats.Msg) { - // Parse out the event JSON - userID := msg.Header.Get(jetstream.UserID) - var output eventutil.AccountData - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("client API server output log: message parse failure") - sentry.CaptureException(err) - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Parse out the event JSON + userID := msg.Header.Get(jetstream.UserID) + var output eventutil.AccountData + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("client API server output log: message parse failure") + sentry.CaptureException(err) + return true + } - log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - }).Info("received data from client API server") - - streamPos, err := s.db.UpsertAccountData( - context.TODO(), userID, output.RoomID, output.Type, - ) - if err != nil { - sentry.CaptureException(err) log.WithFields(log.Fields{ - "type": output.Type, - "room_id": output.RoomID, - log.ErrorKey: err, - }).Panicf("could not save account data") - } + "type": output.Type, + "room_id": output.RoomID, + }).Info("received data from client API server") - s.stream.Advance(streamPos) - s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) + streamPos, err := s.db.UpsertAccountData( + s.ctx, userID, output.RoomID, output.Type, + ) + if err != nil { + sentry.CaptureException(err) + log.WithFields(log.Fields{ + "type": output.Type, + "room_id": output.RoomID, + log.ErrorKey: err, + }).Panicf("could not save account data") + } - _ = msg.Ack() + s.stream.Advance(streamPos) + s.notifier.OnNewAccountData(userID, types.StreamingToken{AccountDataPosition: streamPos}) + + return true + }) } diff --git a/syncapi/consumers/eduserver_receipts.go b/syncapi/consumers/eduserver_receipts.go index 99189a78..582e1d64 100644 --- a/syncapi/consumers/eduserver_receipts.go +++ b/syncapi/consumers/eduserver_receipts.go @@ -32,6 +32,7 @@ import ( // OutputReceiptEventConsumer consumes events that originated in the EDU server. type OutputReceiptEventConsumer struct { + ctx context.Context jetstream nats.JetStreamContext topic string db storage.Database @@ -50,6 +51,7 @@ func NewOutputReceiptEventConsumer( stream types.StreamProvider, ) *OutputReceiptEventConsumer { return &OutputReceiptEventConsumer{ + ctx: process.Context(), jetstream: js, topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputReceiptEvent), db: store, @@ -65,30 +67,31 @@ func (s *OutputReceiptEventConsumer) Start() error { } func (s *OutputReceiptEventConsumer) onMessage(msg *nats.Msg) { - var output api.OutputReceiptEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + var output api.OutputReceiptEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) + return true + } - streamPos, err := s.db.StoreReceipt( - context.TODO(), - output.RoomID, - output.Type, - output.UserID, - output.EventID, - output.Timestamp, - ) - if err != nil { - sentry.CaptureException(err) - return - } + streamPos, err := s.db.StoreReceipt( + s.ctx, + output.RoomID, + output.Type, + output.UserID, + output.EventID, + output.Timestamp, + ) + if err != nil { + sentry.CaptureException(err) + return true + } - s.stream.Advance(streamPos) - s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) + s.stream.Advance(streamPos) + s.notifier.OnNewReceipt(output.RoomID, types.StreamingToken{ReceiptPosition: streamPos}) - _ = msg.Ack() + return true + }) } diff --git a/syncapi/consumers/eduserver_sendtodevice.go b/syncapi/consumers/eduserver_sendtodevice.go index 9ffd572e..6579c303 100644 --- a/syncapi/consumers/eduserver_sendtodevice.go +++ b/syncapi/consumers/eduserver_sendtodevice.go @@ -34,6 +34,7 @@ import ( // OutputSendToDeviceEventConsumer consumes events that originated in the EDU server. type OutputSendToDeviceEventConsumer struct { + ctx context.Context jetstream nats.JetStreamContext topic string db storage.Database @@ -53,6 +54,7 @@ func NewOutputSendToDeviceEventConsumer( stream types.StreamProvider, ) *OutputSendToDeviceEventConsumer { return &OutputSendToDeviceEventConsumer{ + ctx: process.Context(), jetstream: js, topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputSendToDeviceEvent), db: store, @@ -69,48 +71,47 @@ func (s *OutputSendToDeviceEventConsumer) Start() error { } func (s *OutputSendToDeviceEventConsumer) onMessage(msg *nats.Msg) { - var output api.OutputSendToDeviceEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + var output api.OutputSendToDeviceEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) + return true + } - _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) - if err != nil { - sentry.CaptureException(err) - _ = msg.Ack() - return - } - if domain != s.serverName { - _ = msg.Ack() - return - } + _, domain, err := gomatrixserverlib.SplitID('@', output.UserID) + if err != nil { + sentry.CaptureException(err) + return true + } + if domain != s.serverName { + return true + } - util.GetLogger(context.TODO()).WithFields(log.Fields{ - "sender": output.Sender, - "user_id": output.UserID, - "device_id": output.DeviceID, - "event_type": output.Type, - }).Info("sync API received send-to-device event from EDU server") + util.GetLogger(context.TODO()).WithFields(log.Fields{ + "sender": output.Sender, + "user_id": output.UserID, + "device_id": output.DeviceID, + "event_type": output.Type, + }).Info("sync API received send-to-device event from EDU server") - streamPos, err := s.db.StoreNewSendForDeviceMessage( - context.TODO(), output.UserID, output.DeviceID, output.SendToDeviceEvent, - ) - if err != nil { - sentry.CaptureException(err) - log.WithError(err).Errorf("failed to store send-to-device message") - return - } + streamPos, err := s.db.StoreNewSendForDeviceMessage( + s.ctx, output.UserID, output.DeviceID, output.SendToDeviceEvent, + ) + if err != nil { + sentry.CaptureException(err) + log.WithError(err).Errorf("failed to store send-to-device message") + return false + } - s.stream.Advance(streamPos) - s.notifier.OnNewSendToDevice( - output.UserID, - []string{output.DeviceID}, - types.StreamingToken{SendToDevicePosition: streamPos}, - ) + s.stream.Advance(streamPos) + s.notifier.OnNewSendToDevice( + output.UserID, + []string{output.DeviceID}, + types.StreamingToken{SendToDevicePosition: streamPos}, + ) - _ = msg.Ack() + return true + }) } diff --git a/syncapi/consumers/eduserver_typing.go b/syncapi/consumers/eduserver_typing.go index b69293c6..487befe8 100644 --- a/syncapi/consumers/eduserver_typing.go +++ b/syncapi/consumers/eduserver_typing.go @@ -15,6 +15,7 @@ package consumers import ( + "context" "encoding/json" "github.com/getsentry/sentry-go" @@ -32,6 +33,7 @@ import ( // OutputTypingEventConsumer consumes events that originated in the EDU server. type OutputTypingEventConsumer struct { + ctx context.Context jetstream nats.JetStreamContext topic string eduCache *cache.EDUCache @@ -51,6 +53,7 @@ func NewOutputTypingEventConsumer( stream types.StreamProvider, ) *OutputTypingEventConsumer { return &OutputTypingEventConsumer{ + ctx: process.Context(), jetstream: js, topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputTypingEvent), eduCache: eduCache, @@ -66,35 +69,36 @@ func (s *OutputTypingEventConsumer) Start() error { } func (s *OutputTypingEventConsumer) onMessage(msg *nats.Msg) { - var output api.OutputTypingEvent - if err := json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("EDU server output log: message parse failure") - sentry.CaptureException(err) - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + var output api.OutputTypingEvent + if err := json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("EDU server output log: message parse failure") + sentry.CaptureException(err) + return true + } - log.WithFields(log.Fields{ - "room_id": output.Event.RoomID, - "user_id": output.Event.UserID, - "typing": output.Event.Typing, - }).Debug("received data from EDU server") + log.WithFields(log.Fields{ + "room_id": output.Event.RoomID, + "user_id": output.Event.UserID, + "typing": output.Event.Typing, + }).Debug("received data from EDU server") - var typingPos types.StreamPosition - typingEvent := output.Event - if typingEvent.Typing { - typingPos = types.StreamPosition( - s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), - ) - } else { - typingPos = types.StreamPosition( - s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), - ) - } + var typingPos types.StreamPosition + typingEvent := output.Event + if typingEvent.Typing { + typingPos = types.StreamPosition( + s.eduCache.AddTypingUser(typingEvent.UserID, typingEvent.RoomID, output.ExpireTime), + ) + } else { + typingPos = types.StreamPosition( + s.eduCache.RemoveUser(typingEvent.UserID, typingEvent.RoomID), + ) + } - s.stream.Advance(typingPos) - s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) + s.stream.Advance(typingPos) + s.notifier.OnNewTyping(output.Event.RoomID, types.StreamingToken{TypingPosition: typingPos}) - _ = msg.Ack() + return true + }) } diff --git a/syncapi/consumers/keychange.go b/syncapi/consumers/keychange.go index af7162f3..76b143d8 100644 --- a/syncapi/consumers/keychange.go +++ b/syncapi/consumers/keychange.go @@ -34,6 +34,7 @@ import ( // OutputKeyChangeEventConsumer consumes events that originated in the key server. type OutputKeyChangeEventConsumer struct { + ctx context.Context keyChangeConsumer *internal.ContinualConsumer db storage.Database notifier *notifier.Notifier @@ -68,6 +69,7 @@ func NewOutputKeyChangeEventConsumer( } s := &OutputKeyChangeEventConsumer{ + ctx: process.Context(), keyChangeConsumer: &consumer, db: store, serverName: serverName, @@ -131,7 +133,7 @@ func (s *OutputKeyChangeEventConsumer) onDeviceKeyMessage(m api.DeviceMessage, o output := m.DeviceKeys // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse - err := s.rsAPI.QuerySharedUsers(context.Background(), &roomserverAPI.QuerySharedUsersRequest{ + err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ UserID: output.UserID, }, &queryRes) if err != nil { @@ -158,7 +160,7 @@ func (s *OutputKeyChangeEventConsumer) onCrossSigningMessage(m api.DeviceMessage output := m.CrossSigningKeyUpdate // work out who we need to notify about the new key var queryRes roomserverAPI.QuerySharedUsersResponse - err := s.rsAPI.QuerySharedUsers(context.Background(), &roomserverAPI.QuerySharedUsersRequest{ + err := s.rsAPI.QuerySharedUsers(s.ctx, &roomserverAPI.QuerySharedUsersRequest{ UserID: output.UserID, }, &queryRes) if err != nil { diff --git a/syncapi/consumers/roomserver.go b/syncapi/consumers/roomserver.go index e85a181d..5b008e3d 100644 --- a/syncapi/consumers/roomserver.go +++ b/syncapi/consumers/roomserver.go @@ -34,6 +34,7 @@ import ( // OutputRoomEventConsumer consumes events that originated in the room server. type OutputRoomEventConsumer struct { + ctx context.Context cfg *config.SyncAPI rsAPI api.RoomserverInternalAPI jetstream nats.JetStreamContext @@ -56,6 +57,7 @@ func NewOutputRoomEventConsumer( rsAPI api.RoomserverInternalAPI, ) *OutputRoomEventConsumer { return &OutputRoomEventConsumer{ + ctx: process.Context(), cfg: cfg, jetstream: js, topic: cfg.Matrix.JetStream.TopicFor(jetstream.OutputRoomEvent), @@ -77,55 +79,53 @@ func (s *OutputRoomEventConsumer) Start() error { // It is not safe for this function to be called from multiple goroutines, or else the // sync stream position may race and be incorrectly calculated. func (s *OutputRoomEventConsumer) onMessage(msg *nats.Msg) { - // Parse out the event JSON - var err error - var output api.OutputEvent - if err = json.Unmarshal(msg.Data, &output); err != nil { - // If the message was invalid, log it and move on to the next message in the stream - log.WithError(err).Errorf("roomserver output log: message parse failure") - _ = msg.Ack() - return - } - - switch output.Type { - case api.OutputTypeNewRoomEvent: - // Ignore redaction events. We will add them to the database when they are - // validated (when we receive OutputTypeRedactedEvent) - event := output.NewRoomEvent.Event - if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { - // in the special case where the event redacts itself, just pass the message through because - // we will never see the other part of the pair - if event.Redacts() != event.EventID() { - _ = msg.Ack() - return - } + jetstream.WithJetStreamMessage(msg, func(msg *nats.Msg) bool { + // Parse out the event JSON + var err error + var output api.OutputEvent + if err = json.Unmarshal(msg.Data, &output); err != nil { + // If the message was invalid, log it and move on to the next message in the stream + log.WithError(err).Errorf("roomserver output log: message parse failure") + return true } - err = s.onNewRoomEvent(context.TODO(), *output.NewRoomEvent) - case api.OutputTypeOldRoomEvent: - err = s.onOldRoomEvent(context.TODO(), *output.OldRoomEvent) - case api.OutputTypeNewInviteEvent: - s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent) - case api.OutputTypeRetireInviteEvent: - s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent) - case api.OutputTypeNewPeek: - s.onNewPeek(context.TODO(), *output.NewPeek) - case api.OutputTypeRetirePeek: - s.onRetirePeek(context.TODO(), *output.RetirePeek) - case api.OutputTypeRedactedEvent: - err = s.onRedactEvent(context.TODO(), *output.RedactedEvent) - default: - log.WithField("type", output.Type).Debug( - "roomserver output log: ignoring unknown output type", - ) - _ = msg.Ack() - } - if err != nil { - log.WithError(err).Error("roomserver output log: failed to process event") - _ = msg.Nak() - return - } - _ = msg.Ack() + switch output.Type { + case api.OutputTypeNewRoomEvent: + // Ignore redaction events. We will add them to the database when they are + // validated (when we receive OutputTypeRedactedEvent) + event := output.NewRoomEvent.Event + if event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil { + // in the special case where the event redacts itself, just pass the message through because + // we will never see the other part of the pair + if event.Redacts() != event.EventID() { + return true + } + } + err = s.onNewRoomEvent(s.ctx, *output.NewRoomEvent) + case api.OutputTypeOldRoomEvent: + err = s.onOldRoomEvent(s.ctx, *output.OldRoomEvent) + case api.OutputTypeNewInviteEvent: + s.onNewInviteEvent(s.ctx, *output.NewInviteEvent) + case api.OutputTypeRetireInviteEvent: + s.onRetireInviteEvent(s.ctx, *output.RetireInviteEvent) + case api.OutputTypeNewPeek: + s.onNewPeek(s.ctx, *output.NewPeek) + case api.OutputTypeRetirePeek: + s.onRetirePeek(s.ctx, *output.RetirePeek) + case api.OutputTypeRedactedEvent: + err = s.onRedactEvent(s.ctx, *output.RedactedEvent) + default: + log.WithField("type", output.Type).Debug( + "roomserver output log: ignoring unknown output type", + ) + } + if err != nil { + log.WithError(err).Error("roomserver output log: failed to process event") + return false + } + + return true + }) } func (s *OutputRoomEventConsumer) onRedactEvent(