diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 3ad4c202..4e0ff274 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -101,20 +101,27 @@ func (r *Inputer) processRoomEvent( // First of all, check that the auth events of the event are known. // If they aren't then we will ask the federation API for them. - if err := r.checkForMissingAuthEvents(ctx, input.Event, map[string]types.EventNID{}); err != nil { - logrus.WithError(err).Error("XXX: r.checkForMissingAuthEvents") + isRejected := false + authEvents := gomatrixserverlib.NewAuthEvents(nil) + knownAuthEvents := map[string]types.Event{} + if err := r.checkForMissingAuthEvents(ctx, input.Event, &authEvents, knownAuthEvents); err != nil { return "", fmt.Errorf("r.checkForMissingAuthEvents: %w", err) } - // Check that the event passes authentication checks and work out - // the numeric IDs for the auth events. - isRejected := false - authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) - if rejectionErr != nil { - logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("helpers.CheckAuthEvents failed for event, rejecting event") + // Check if the event is allowed by its auth events. If it isn't then + // we consider the event to be "rejected" — it will still be persisted. + var rejectionErr error + if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil { isRejected = true } + // Accumulate the auth event NIDs. + authEventIDs := event.AuthEventIDs() + authEventNIDs := make([]types.EventNID, 0, len(authEventIDs)) + for _, authEventID := range authEventIDs { + authEventNIDs = append(authEventNIDs, knownAuthEvents[authEventID].EventNID) + } + // Then check if the prev events are known, which we need in order // to calculate the state before the event. if err := r.checkForMissingPrevEvents(ctx, input); err != nil { @@ -246,29 +253,32 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) checkForMissingAuthEvents( ctx context.Context, event *gomatrixserverlib.HeaderedEvent, - cache map[string]types.EventNID, + auth *gomatrixserverlib.AuthEvents, + known map[string]types.Event, ) error { authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { return nil } - knownAuthEventNIDs, err := r.DB.EventNIDs(ctx, authEventIDs) - if err != nil { - return fmt.Errorf("r.DB.EventNIDs: %w", err) - } - for authEventID, authEventNID := range knownAuthEventNIDs { - cache[authEventID] = authEventNID - } + unknown := map[string]struct{}{} - missingAuthEventIDs := make([]string, 0, len(authEventIDs)-len(knownAuthEventNIDs)) - for _, authEventID := range authEventIDs { - if _, ok := knownAuthEventNIDs[authEventID]; !ok { - missingAuthEventIDs = append(missingAuthEventIDs, authEventID) + authEvents, err := r.DB.EventsFromIDs(ctx, authEventIDs) + if err != nil { + return fmt.Errorf("r.DB.EventsFromIDs: %w", err) + } + for _, event := range authEvents { + if event.Event != nil { + known[event.EventID()] = event + if err := auth.AddEvent(event.Event); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) + } + } else { + unknown[event.EventID()] = struct{}{} } } - if len(missingAuthEventIDs) > 0 { + if len(unknown) > 0 { req := &fedapi.QueryEventAuthFromFederationRequest{ RoomID: event.RoomID(), EventID: event.EventID(), @@ -282,45 +292,40 @@ func (r *Inputer) checkForMissingAuthEvents( res.Events, gomatrixserverlib.TopologicalOrderByAuthEvents, ) { - // Work out which event NIDs we need to look up from the database. If - // the event NID is already in the event map in memory then we can don't - // need to ask the database again. - neededAuthEventNIDs := make([]string, 0, len(event.AuthEventIDs())) - for _, authEventID := range event.AuthEventIDs() { - if _, ok := cache[authEventID]; !ok { - neededAuthEventNIDs = append(neededAuthEventNIDs, authEventID) - } + // If we already know about this event then we don't need to store + // it or do anything further with it. + if _, ok := known[event.EventID()]; ok { + continue } - // If we need to fetch some event NIDs from the database then do that. - // We will also add those to the auth event map in memory, so that we - // can skip future database hits for the same event IDs. - if len(neededAuthEventNIDs) > 0 { - newAuthEventNIDs, err := r.DB.EventNIDs(ctx, neededAuthEventNIDs) - if err != nil { - return fmt.Errorf("r.DB.EventNIDs: %w", err) - } - for authEventID, authEventNID := range newAuthEventNIDs { - cache[authEventID] = authEventNID + // Otherwise, we need to store, and that means we need to know the + // auth event NIDs. Let's see if we can find those. + authEventNIDs := make([]types.EventNID, 0, len(event.AuthEventIDs())) + for _, eventID := range event.AuthEventIDs() { + knownEvent, ok := known[eventID] + if !ok { + return fmt.Errorf("missing auth event %s for %s", eventID, event.EventID()) } + authEventNIDs = append(authEventNIDs, knownEvent.EventNID) } - // Now collect the event NIDs for all of the auth events. - authEventNIDsForEvent := make([]types.EventNID, 0, len(event.AuthEventIDs())) - for _, authEventID := range event.AuthEventIDs() { - authEventNIDsForEvent = append(authEventNIDsForEvent, cache[authEventID]) - } - - // If we haven't accumulated all of the auth events needed for the - // event then we shouldn't persist the event as something is wrong. - if len(authEventNIDsForEvent) != len(event.AuthEventIDs()) { - return fmt.Errorf("missing auth event NIDs for event %s", event.EventID()) + // Let's take a note of the fact that we now know about this event. + known[event.EventID()] = types.Event{} + if err := auth.AddEvent(event); err != nil { + return fmt.Errorf("auth.AddEvent: %w", err) } // Finally, store the event in the database. - if _, _, _, _, err := r.DB.StoreEvent(ctx, event, authEventNIDsForEvent, false); err != nil { + eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, event, authEventNIDs, false) + if err != nil { return fmt.Errorf("r.DB.StoreEvent: %w", err) } + + // Now we know about this event, too. + known[event.EventID()] = types.Event{ + EventNID: eventNID, + Event: event, + } } }