Add Queryer and Inputer and factor out more RSAPI stuff (#1382)

* Add Queryer and use embedded structs

* Add Inputer and factor out more RS API stuff

This neatly splits up the RS API based on the functionality it provides,
whilst providing a useful place for code sharing via the `helpers` package.
This commit is contained in:
Kegsay 2020-09-02 17:13:15 +01:00 committed by GitHub
parent f06637435b
commit 9d9e854fe0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 292 additions and 185 deletions

View file

@ -0,0 +1,91 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package input contains the code processes new room events
package input
import (
"context"
"encoding/json"
"sync"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
type Inputer struct {
DB storage.Database
Producer sarama.SyncProducer
ServerName gomatrixserverlib.ServerName
OutputRoomEventTopic string
mutexes sync.Map // room ID -> *sync.Mutex, protects calls to processRoomEvent
}
// WriteOutputEvents implements OutputRoomEventWriter
func (r *Inputer) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
messages := make([]*sarama.ProducerMessage, len(updates))
for i := range updates {
value, err := json.Marshal(updates[i])
if err != nil {
return err
}
logger := log.WithFields(log.Fields{
"room_id": roomID,
"type": updates[i].Type,
})
if updates[i].NewRoomEvent != nil {
logger = logger.WithFields(log.Fields{
"event_type": updates[i].NewRoomEvent.Event.Type(),
"event_id": updates[i].NewRoomEvent.Event.EventID(),
"adds_state": len(updates[i].NewRoomEvent.AddsStateEventIDs),
"removes_state": len(updates[i].NewRoomEvent.RemovesStateEventIDs),
"send_as_server": updates[i].NewRoomEvent.SendAsServer,
"sender": updates[i].NewRoomEvent.Event.Sender(),
})
}
logger.Infof("Producing to topic '%s'", r.OutputRoomEventTopic)
messages[i] = &sarama.ProducerMessage{
Topic: r.OutputRoomEventTopic,
Key: sarama.StringEncoder(roomID),
Value: sarama.ByteEncoder(value),
}
}
return r.Producer.SendMessages(messages)
}
// InputRoomEvents implements api.RoomserverInternalAPI
func (r *Inputer) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) (err error) {
for i, e := range request.InputRoomEvents {
roomID := "global"
if r.DB.SupportsConcurrentRoomInputs() {
roomID = e.Event.RoomID()
}
mutex, _ := r.mutexes.LoadOrStore(roomID, &sync.Mutex{})
mutex.(*sync.Mutex).Lock()
if response.EventID, err = r.processRoomEvent(ctx, request.InputRoomEvents[i]); err != nil {
mutex.(*sync.Mutex).Unlock()
return err
}
mutex.(*sync.Mutex).Unlock()
}
return nil
}

View file

@ -0,0 +1,185 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package input
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/eventutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// processRoomEvent can only be called once at a time
//
// TODO(#375): This should be rewritten to allow concurrent calls. The
// difficulty is in ensuring that we correctly annotate events with the correct
// state deltas when sending to kafka streams
// TODO: Break up function - we should probably do transaction ID checks before calling this.
// nolint:gocyclo
func (r *Inputer) processRoomEvent(
ctx context.Context,
input api.InputRoomEvent,
) (eventID string, err error) {
// Parse and validate the event JSON
headered := input.Event
event := headered.Unwrap()
// Check that the event passes authentication checks and work out
// the numeric IDs for the auth events.
authEventNIDs, err := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("processRoomEvent.checkAuthEvents failed for event")
return
}
// If we don't have a transaction ID then get one.
if input.TransactionID != nil {
tdID := input.TransactionID
eventID, err = r.DB.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.SessionID, event.Sender(),
)
// On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" {
return
}
}
// Store the event.
_, stateAtEvent, redactionEvent, redactedEventID, err := r.DB.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil {
return "", fmt.Errorf("r.DB.StoreEvent: %w", err)
}
// if storing this event results in it being redacted then do so.
if redactedEventID == event.EventID() {
r, rerr := eventutil.RedactEvent(redactionEvent, &event)
if rerr != nil {
return "", fmt.Errorf("eventutil.RedactEvent: %w", rerr)
}
event = *r
}
// For outliers we can stop after we've stored the event itself as it
// doesn't have any associated state to store and we don't need to
// notify anyone about it.
if input.Kind == api.KindOutlier {
logrus.WithFields(logrus.Fields{
"event_id": event.EventID(),
"type": event.Type(),
"room": event.RoomID(),
}).Info("Stored outlier")
return event.EventID(), nil
}
roomInfo, err := r.DB.RoomInfo(ctx, event.RoomID())
if err != nil {
return "", fmt.Errorf("r.DB.RoomInfo: %w", err)
}
if roomInfo == nil {
return "", fmt.Errorf("r.DB.RoomInfo missing for room %s", event.RoomID())
}
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
err = r.calculateAndSetState(ctx, input, *roomInfo, &stateAtEvent, event)
if err != nil {
return "", fmt.Errorf("r.calculateAndSetState: %w", err)
}
}
if err = r.updateLatestEvents(
ctx, // context
roomInfo, // room info for the room being updated
stateAtEvent, // state at event (below)
event, // event
input.SendAsServer, // send as server
input.TransactionID, // transaction ID
); err != nil {
return "", fmt.Errorf("r.updateLatestEvents: %w", err)
}
// processing this event resulted in an event (which may not be the one we're processing)
// being redacted. We are guaranteed to have both sides (the redaction/redacted event),
// so notify downstream components to redact this event - they should have it if they've
// been tracking our output log.
if redactedEventID != "" {
err = r.WriteOutputEvents(event.RoomID(), []api.OutputEvent{
{
Type: api.OutputTypeRedactedEvent,
RedactedEvent: &api.OutputRedactedEvent{
RedactedEventID: redactedEventID,
RedactedBecause: redactionEvent.Headered(headered.RoomVersion),
},
},
})
if err != nil {
return "", fmt.Errorf("r.WriteOutputEvents: %w", err)
}
}
// Update the extremities of the event graph for the room
return event.EventID(), nil
}
func (r *Inputer) calculateAndSetState(
ctx context.Context,
input api.InputRoomEvent,
roomInfo types.RoomInfo,
stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event,
) error {
var err error
roomState := state.NewStateResolution(r.DB, roomInfo)
if input.HasState {
// Check here if we think we're in the room already.
stateAtEvent.Overwrite = true
var joinEventNIDs []types.EventNID
// Request join memberships only for local users only.
if joinEventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true); err == nil {
// If we have no local users that are joined to the room then any state about
// the room that we have is quite possibly out of date. Therefore in that case
// we should overwrite it rather than merge it.
stateAtEvent.Overwrite = len(joinEventNIDs) == 0
}
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err
}
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
return err
}
} else {
stateAtEvent.Overwrite = false
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event); err != nil {
return err
}
}
return r.DB.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}

View file

@ -0,0 +1,390 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package input
import (
"bytes"
"context"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// updateLatestEvents updates the list of latest events for this room in the database and writes the
// event to the output log.
// The latest events are the events that aren't referenced by another event in the database:
//
// Time goes down the page. 1 is the m.room.create event (root).
//
// 1 After storing 1 the latest events are {1}
// | After storing 2 the latest events are {2}
// 2 After storing 3 the latest events are {3}
// / \ After storing 4 the latest events are {3,4}
// 3 4 After storing 5 the latest events are {5,4}
// | | After storing 6 the latest events are {5,6}
// 5 6 <--- latest After storing 7 the latest events are {6,7}
// |
// 7 <----- latest
//
// Can only be called once at a time
func (r *Inputer) updateLatestEvents(
ctx context.Context,
roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event,
sendAsServer string,
transactionID *api.TransactionID,
) (err error) {
updater, err := r.DB.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return fmt.Errorf("r.DB.GetLatestEventsForUpdate: %w", err)
}
succeeded := false
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
u := latestEventsUpdater{
ctx: ctx,
api: r,
updater: updater,
roomInfo: roomInfo,
stateAtEvent: stateAtEvent,
event: event,
sendAsServer: sendAsServer,
transactionID: transactionID,
}
if err = u.doUpdateLatestEvents(); err != nil {
return fmt.Errorf("u.doUpdateLatestEvents: %w", err)
}
succeeded = true
return
}
// latestEventsUpdater tracks the state used to update the latest events in the
// room. It mostly just ferries state between the various function calls.
// The state could be passed using function arguments, but it becomes impractical
// when there are so many variables to pass around.
type latestEventsUpdater struct {
ctx context.Context
api *Inputer
updater *shared.LatestEventsUpdater
roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event
transactionID *api.TransactionID
// Which server to send this event as.
sendAsServer string
// The eventID of the event that was processed before this one.
lastEventIDSent string
// The latest events in the room after processing this event.
latest []types.StateAtEventAndReference
// The state entries removed from and added to the current state of the
// room as a result of processing this event. They are sorted lists.
removed []types.StateEntry
added []types.StateEntry
// The state entries that are removed and added to recover the state before
// the event being processed. They are sorted lists.
stateBeforeEventRemoves []types.StateEntry
stateBeforeEventAdds []types.StateEntry
// The snapshots of current state before and after processing this event
oldStateNID types.StateSnapshotNID
newStateNID types.StateSnapshotNID
}
func (u *latestEventsUpdater) doUpdateLatestEvents() error {
prevEvents := u.event.PrevEvents()
u.lastEventIDSent = u.updater.LastEventIDSent()
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
// If we are doing a regular event update then we will get the
// previous latest events to use as a part of the calculation. If
// we are overwriting the latest events because we have a complete
// state snapshot from somewhere else, e.g. a federated room join,
// then start with an empty set - none of the forward extremities
// that we knew about before matter anymore.
oldLatest := []types.StateAtEventAndReference{}
if !u.stateAtEvent.Overwrite {
oldLatest = u.updater.LatestEvents()
}
// If the event has already been written to the output log then we
// don't need to do anything, as we've handled it already.
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
if err != nil {
return fmt.Errorf("u.updater.HasEventBeenSent: %w", err)
} else if hasBeenSent {
return nil
}
// Update the roomserver_previous_events table with references. This
// is effectively tracking the structure of the DAG.
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
return fmt.Errorf("u.updater.StorePreviousEvents: %w", err)
}
// Get the event reference for our new event. This will be used when
// determining if the event is referenced by an existing event.
eventReference := u.event.EventReference()
// Check if our new event is already referenced by an existing event
// in the room. If it is then it isn't a latest event.
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
if err != nil {
return fmt.Errorf("u.updater.IsReferenced: %w", err)
}
// Work out what the latest events are.
u.latest = calculateLatest(
oldLatest,
alreadyReferenced,
prevEvents,
types.StateAtEventAndReference{
EventReference: eventReference,
StateAtEvent: u.stateAtEvent,
},
)
// Now that we know what the latest events are, it's time to get the
// latest state.
if err = u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err)
}
// If we need to generate any output events then here's where we do it.
// TODO: Move this!
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added)
if err != nil {
return fmt.Errorf("u.api.updateMemberships: %w", err)
}
update, err := u.makeOutputNewRoomEvent()
if err != nil {
return fmt.Errorf("u.makeOutputNewRoomEvent: %w", err)
}
updates = append(updates, *update)
// Send the event to the output logs.
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
// the write to the output log succeeds)
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.api.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return fmt.Errorf("u.api.WriteOutputEvents: %w", err)
}
if err = u.updater.SetLatestEvents(u.roomInfo.RoomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
return fmt.Errorf("u.updater.SetLatestEvents: %w", err)
}
if err = u.updater.MarkEventAsSent(u.stateAtEvent.EventNID); err != nil {
return fmt.Errorf("u.updater.MarkEventAsSent: %w", err)
}
return nil
}
func (u *latestEventsUpdater) latestState() error {
var err error
roomState := state.NewStateResolution(u.api.DB, *u.roomInfo)
// Get a list of the current latest events.
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
for i := range u.latest {
latestStateAtEvents[i] = u.latest[i].StateAtEvent
}
// Takes the NIDs of the latest events and creates a state snapshot
// of the state after the events. The snapshot state will be resolved
// using the correct state resolution algorithm for the room.
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
u.ctx, latestStateAtEvents,
)
if err != nil {
return fmt.Errorf("roomState.CalculateAndStoreStateAfterEvents: %w", err)
}
// If we are overwriting the state then we should make sure that we
// don't send anything out over federation again, it will very likely
// be a repeat.
if u.stateAtEvent.Overwrite {
u.sendAsServer = ""
}
// Now that we have a new state snapshot based on the latest events,
// we can compare that new snapshot to the previous one and see what
// has changed. This gives us one list of removed state events and
// another list of added ones. Replacing a value for a state-key tuple
// will result one removed (the old event) and one added (the new event).
u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.oldStateNID, u.newStateNID,
)
if err != nil {
return fmt.Errorf("roomState.DifferenceBetweenStateSnapshots: %w", err)
}
// Also work out the state before the event removes and the event
// adds.
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
)
if err != nil {
return fmt.Errorf("roomState.DifferenceBetweeenStateSnapshots: %w", err)
}
return nil
}
func calculateLatest(
oldLatest []types.StateAtEventAndReference,
alreadyReferenced bool,
prevEvents []gomatrixserverlib.EventReference,
newEvent types.StateAtEventAndReference,
) []types.StateAtEventAndReference {
var alreadyInLatest bool
var newLatest []types.StateAtEventAndReference
for _, l := range oldLatest {
keep := true
for _, prevEvent := range prevEvents {
if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) {
// This event can be removed from the latest events cause we've found an event that references it.
// (If an event is referenced by another event then it can't be one of the latest events in the room
// because we have an event that comes after it)
keep = false
break
}
}
if l.EventNID == newEvent.EventNID {
alreadyInLatest = true
}
if keep {
// Keep the event in the latest events.
newLatest = append(newLatest, l)
}
}
if !alreadyReferenced && !alreadyInLatest {
// This event is not referenced by any of the events in the room
// and the event is not already in the latest events.
// Add it to the latest events
newLatest = append(newLatest, newEvent)
}
return newLatest
}
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
latestEventIDs := make([]string, len(u.latest))
for i := range u.latest {
latestEventIDs[i] = u.latest[i].EventID
}
ore := api.OutputNewRoomEvent{
Event: u.event.Headered(u.roomInfo.RoomVersion),
LastSentEventID: u.lastEventIDSent,
LatestEventIDs: latestEventIDs,
TransactionID: u.transactionID,
}
eventIDMap, err := u.stateEventMap()
if err != nil {
return nil, err
}
for _, entry := range u.added {
ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.removed {
ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.stateBeforeEventRemoves {
ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.stateBeforeEventAdds {
ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
}
ore.SendAsServer = u.sendAsServer
// include extra state events if they were added as nearly every downstream component will care about it
// and we'd rather not have them all hit QueryEventsByID at the same time!
if len(ore.AddsStateEventIDs) > 0 {
ore.AddStateEvents, err = u.extraEventsForIDs(u.roomInfo.RoomVersion, ore.AddsStateEventIDs)
if err != nil {
return nil, fmt.Errorf("failed to load add_state_events from db: %w", err)
}
}
return &api.OutputEvent{
Type: api.OutputTypeNewRoomEvent,
NewRoomEvent: &ore,
}, nil
}
// extraEventsForIDs returns the full events for the event IDs given, but does not include the current event being
// updated.
func (u *latestEventsUpdater) extraEventsForIDs(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.HeaderedEvent, error) {
var extraEventIDs []string
for _, e := range eventIDs {
if e == u.event.EventID() {
continue
}
extraEventIDs = append(extraEventIDs, e)
}
if len(extraEventIDs) == 0 {
return nil, nil
}
extraEvents, err := u.api.DB.EventsFromIDs(u.ctx, extraEventIDs)
if err != nil {
return nil, err
}
var h []gomatrixserverlib.HeaderedEvent
for _, e := range extraEvents {
h = append(h, e.Headered(roomVersion))
}
return h, nil
}
// retrieve an event nid -> event ID map for all events that need updating
func (u *latestEventsUpdater) stateEventMap() (map[types.EventNID]string, error) {
var stateEventNIDs []types.EventNID
var allStateEntries []types.StateEntry
allStateEntries = append(allStateEntries, u.added...)
allStateEntries = append(allStateEntries, u.removed...)
allStateEntries = append(allStateEntries, u.stateBeforeEventRemoves...)
allStateEntries = append(allStateEntries, u.stateBeforeEventAdds...)
for _, entry := range allStateEntries {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
return u.api.DB.EventIDs(u.ctx, stateEventNIDs)
}
type eventNIDSorter []types.EventNID
func (s eventNIDSorter) Len() int { return len(s) }
func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -0,0 +1,267 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package input
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// updateMembership updates the current membership and the invites for each
// user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state.
func (r *Inputer) updateMemberships(
ctx context.Context,
updater *shared.LatestEventsUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID
for _, change := range changes {
if change.addedEventNID != 0 {
eventNIDs = append(eventNIDs, change.addedEventNID)
}
if change.removedEventNID != 0 {
eventNIDs = append(eventNIDs, change.removedEventNID)
}
}
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
var updates []api.OutputEvent
for _, change := range changes {
var ae *gomatrixserverlib.Event
var re *gomatrixserverlib.Event
targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID)
if ev != nil {
re = &ev.Event
}
}
if change.addedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID)
if ev != nil {
ae = &ev.Event
}
}
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
}
}
return updates, nil
}
func (r *Inputer) updateMembership(
updater *shared.LatestEventsUpdater,
targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
var err error
// Default the membership to Leave if no event was added or removed.
oldMembership := gomatrixserverlib.Leave
newMembership := gomatrixserverlib.Leave
if remove != nil {
oldMembership, err = remove.Membership()
if err != nil {
return nil, err
}
}
if add != nil {
newMembership, err = add.Membership()
if err != nil {
return nil, err
}
}
if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return
// immediately, unless it's a Join update (e.g. profile update).
return updates, nil
}
if add == nil {
// This can happen when we have rejoined a room and suddenly we have a
// divergence between the former state and the new one. We don't want to
// act on removals and apparently there are no adds, so stop here.
return updates, nil
}
mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
if err != nil {
return nil, err
}
switch newMembership {
case gomatrixserverlib.Invite:
return helpers.UpdateToInviteMembership(mu, add, updates, updater.RoomVersion())
case gomatrixserverlib.Join:
return updateToJoinMembership(mu, add, updates)
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
return updateToLeaveMembership(mu, add, newMembership, updates)
default:
panic(fmt.Errorf(
"input: membership %q is not one of the allowed values", newMembership,
))
}
}
func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
isTargetLocalUser = domain == r.ServerName
}
return isTargetLocalUser
}
func updateToJoinMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// If the user is already marked as being joined, we call SetToJoin to update
// the event ID then we can return immediately. Retired is ignored as there
// is no invite event to retire.
if mu.IsJoin() {
_, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
if err != nil {
return nil, err
}
return updates, nil
}
// When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
if err != nil {
return nil, err
}
for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
return updates, nil
}
func updateToLeaveMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event,
newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// If the user is already neither joined, nor invited to the room then we
// can return immediately.
if mu.IsLeave() {
return updates, nil
}
// When we mark a user as having left we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
retired, err := mu.SetToLeave(add.Sender(), add.EventID())
if err != nil {
return nil, err
}
for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: newMembership,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
return updates, nil
}
// membershipChanges pairs up the membership state changes.
func membershipChanges(removed, added []types.StateEntry) []stateChange {
changes := pairUpChanges(removed, added)
var result []stateChange
for _, c := range changes {
if c.EventTypeNID == types.MRoomMemberNID {
result = append(result, c)
}
}
return result
}
type stateChange struct {
types.StateKeyTuple
removedEventNID types.EventNID
addedEventNID types.EventNID
}
// pairUpChanges pairs up the state events added and removed for each type,
// state key tuple.
func pairUpChanges(removed, added []types.StateEntry) []stateChange {
tuples := make(map[types.StateKeyTuple]stateChange)
changes := []stateChange{}
// First, go through the newly added state entries.
for _, add := range added {
if change, ok := tuples[add.StateKeyTuple]; ok {
// If we already have an entry, update it.
change.addedEventNID = add.EventNID
tuples[add.StateKeyTuple] = change
} else {
// Otherwise, create a new entry.
tuples[add.StateKeyTuple] = stateChange{add.StateKeyTuple, 0, add.EventNID}
}
}
// Now go through the removed state entries.
for _, remove := range removed {
if change, ok := tuples[remove.StateKeyTuple]; ok {
// If we already have an entry, update it.
change.removedEventNID = remove.EventNID
tuples[remove.StateKeyTuple] = change
} else {
// Otherwise, create a new entry.
tuples[remove.StateKeyTuple] = stateChange{remove.StateKeyTuple, remove.EventNID, 0}
}
}
// Now return the changes as an array.
for _, change := range tuples {
changes = append(changes, change)
}
return changes
}