diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 7b8c7bd1..4c7cd46e 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -101,7 +101,7 @@ func (r *Inputer) processRoomEvent( // First of all, check that the auth events of the event are known. // If they aren't then we will ask the federation API for them. - if err := r.checkForMissingAuthEvents(ctx, input); err != nil { + if err := r.checkForMissingAuthEvents(ctx, input.Event, map[string]types.EventNID{}); err != nil { logrus.WithError(err).Error("XXX: r.checkForMissingAuthEvents") return "", fmt.Errorf("r.checkForMissingAuthEvents: %w", err) } @@ -245,21 +245,21 @@ func (r *Inputer) processRoomEvent( func (r *Inputer) checkForMissingAuthEvents( ctx context.Context, - input *api.InputRoomEvent, + event *gomatrixserverlib.HeaderedEvent, + cache map[string]types.EventNID, ) error { - authEventIDs := input.Event.AuthEventIDs() + authEventIDs := event.AuthEventIDs() if len(authEventIDs) == 0 { return nil } - logrus.Printf("XXX: Auth event IDs: %+v", authEventIDs) - knownAuthEventNIDs, err := r.DB.EventNIDs(ctx, authEventIDs) if err != nil { return fmt.Errorf("r.DB.EventNIDs: %w", err) } - - logrus.Printf("XXX: Known auth event IDs: %+v", knownAuthEventNIDs) + for authEventID, authEventNID := range knownAuthEventNIDs { + cache[authEventID] = authEventNID + } missingAuthEventIDs := make([]string, 0, len(authEventIDs)-len(knownAuthEventNIDs)) for _, authEventID := range authEventIDs { @@ -268,21 +268,59 @@ func (r *Inputer) checkForMissingAuthEvents( } } - logrus.Printf("XXX: Missing auth event IDs: %+v", missingAuthEventIDs) - if len(missingAuthEventIDs) > 0 { req := &fedapi.QueryEventAuthFromFederationRequest{ - RoomID: input.Event.RoomID(), - EventID: input.Event.EventID(), + RoomID: event.RoomID(), + EventID: event.EventID(), } res := &fedapi.QueryEventAuthFromFederationResponse{} if err := r.FSAPI.QueryEventAuthFromFederation(ctx, req, res); err != nil { return fmt.Errorf("r.FSAPI.QueryEventAuthFromFederation: %w", err) } - authEventNIDs, rejection := helpers.CheckAuthEvents(ctx, r.DB, input.Event, input.AuthEventIDs) - if _, _, _, _, err := r.DB.StoreEvent(ctx, input.Event.Event, authEventNIDs, rejection != nil); err != nil { - return fmt.Errorf("r.DB.StoreEvent: %w", err) + for _, event := range gomatrixserverlib.ReverseTopologicalOrdering( + res.Events, + gomatrixserverlib.TopologicalOrderByAuthEvents, + ) { + // Work out which event NIDs we need to look up from the database. If + // the event NID is already in the event map in memory then we can don't + // need to ask the database again. + neededAuthEventNIDs := make([]string, 0, len(event.AuthEventIDs())) + for _, authEventID := range event.AuthEventIDs() { + if _, ok := cache[authEventID]; !ok { + neededAuthEventNIDs = append(neededAuthEventNIDs, authEventID) + } + } + + // If we need to fetch some event NIDs from the database then do that. + // We will also add those to the auth event map in memory, so that we + // can skip future database hits for the same event IDs. + if len(neededAuthEventNIDs) > 0 { + newAuthEventNIDs, err := r.DB.EventNIDs(ctx, neededAuthEventNIDs) + if err != nil { + return fmt.Errorf("r.DB.EventNIDs: %w", err) + } + for authEventID, authEventNID := range newAuthEventNIDs { + cache[authEventID] = authEventNID + } + } + + // Now collect the event NIDs for all of the auth events. + authEventNIDsForEvent := make([]types.EventNID, 0, len(event.AuthEventIDs())) + for _, authEventID := range event.AuthEventIDs() { + authEventNIDsForEvent = append(authEventNIDsForEvent, cache[authEventID]) + } + + // If we haven't accumulated all of the auth events needed for the + // event then we shouldn't persist the event as something is wrong. + if len(authEventNIDsForEvent) != len(event.AuthEventIDs()) { + return fmt.Errorf("missing auth event NIDs for event %s", event.EventID()) + } + + // Finally, store the event in the database. + if _, _, _, _, err := r.DB.StoreEvent(ctx, event, authEventNIDsForEvent, false); err != nil { + return fmt.Errorf("r.DB.StoreEvent: %w", err) + } } }