Refactor auth checking

This commit is contained in:
Neil Alexander 2021-12-09 16:12:21 +00:00
parent da3c1e226d
commit 1bba164bf4
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944

View file

@ -101,20 +101,27 @@ func (r *Inputer) processRoomEvent(
// First of all, check that the auth events of the event are known. // First of all, check that the auth events of the event are known.
// If they aren't then we will ask the federation API for them. // If they aren't then we will ask the federation API for them.
if err := r.checkForMissingAuthEvents(ctx, input.Event, map[string]types.EventNID{}); err != nil { isRejected := false
logrus.WithError(err).Error("XXX: r.checkForMissingAuthEvents") authEvents := gomatrixserverlib.NewAuthEvents(nil)
knownAuthEvents := map[string]types.Event{}
if err := r.checkForMissingAuthEvents(ctx, input.Event, &authEvents, knownAuthEvents); err != nil {
return "", fmt.Errorf("r.checkForMissingAuthEvents: %w", err) return "", fmt.Errorf("r.checkForMissingAuthEvents: %w", err)
} }
// Check that the event passes authentication checks and work out // Check if the event is allowed by its auth events. If it isn't then
// the numeric IDs for the auth events. // we consider the event to be "rejected" — it will still be persisted.
isRejected := false var rejectionErr error
authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) if rejectionErr = gomatrixserverlib.Allowed(event, &authEvents); rejectionErr != nil {
if rejectionErr != nil {
logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("helpers.CheckAuthEvents failed for event, rejecting event")
isRejected = true isRejected = true
} }
// Accumulate the auth event NIDs.
authEventIDs := event.AuthEventIDs()
authEventNIDs := make([]types.EventNID, 0, len(authEventIDs))
for _, authEventID := range authEventIDs {
authEventNIDs = append(authEventNIDs, knownAuthEvents[authEventID].EventNID)
}
// Then check if the prev events are known, which we need in order // Then check if the prev events are known, which we need in order
// to calculate the state before the event. // to calculate the state before the event.
if err := r.checkForMissingPrevEvents(ctx, input); err != nil { if err := r.checkForMissingPrevEvents(ctx, input); err != nil {
@ -246,29 +253,32 @@ func (r *Inputer) processRoomEvent(
func (r *Inputer) checkForMissingAuthEvents( func (r *Inputer) checkForMissingAuthEvents(
ctx context.Context, ctx context.Context,
event *gomatrixserverlib.HeaderedEvent, event *gomatrixserverlib.HeaderedEvent,
cache map[string]types.EventNID, auth *gomatrixserverlib.AuthEvents,
known map[string]types.Event,
) error { ) error {
authEventIDs := event.AuthEventIDs() authEventIDs := event.AuthEventIDs()
if len(authEventIDs) == 0 { if len(authEventIDs) == 0 {
return nil return nil
} }
knownAuthEventNIDs, err := r.DB.EventNIDs(ctx, authEventIDs) unknown := map[string]struct{}{}
if err != nil {
return fmt.Errorf("r.DB.EventNIDs: %w", err)
}
for authEventID, authEventNID := range knownAuthEventNIDs {
cache[authEventID] = authEventNID
}
missingAuthEventIDs := make([]string, 0, len(authEventIDs)-len(knownAuthEventNIDs)) authEvents, err := r.DB.EventsFromIDs(ctx, authEventIDs)
for _, authEventID := range authEventIDs { if err != nil {
if _, ok := knownAuthEventNIDs[authEventID]; !ok { return fmt.Errorf("r.DB.EventsFromIDs: %w", err)
missingAuthEventIDs = append(missingAuthEventIDs, authEventID) }
for _, event := range authEvents {
if event.Event != nil {
known[event.EventID()] = event
if err := auth.AddEvent(event.Event); err != nil {
return fmt.Errorf("auth.AddEvent: %w", err)
}
} else {
unknown[event.EventID()] = struct{}{}
} }
} }
if len(missingAuthEventIDs) > 0 { if len(unknown) > 0 {
req := &fedapi.QueryEventAuthFromFederationRequest{ req := &fedapi.QueryEventAuthFromFederationRequest{
RoomID: event.RoomID(), RoomID: event.RoomID(),
EventID: event.EventID(), EventID: event.EventID(),
@ -282,45 +292,40 @@ func (r *Inputer) checkForMissingAuthEvents(
res.Events, res.Events,
gomatrixserverlib.TopologicalOrderByAuthEvents, gomatrixserverlib.TopologicalOrderByAuthEvents,
) { ) {
// Work out which event NIDs we need to look up from the database. If // If we already know about this event then we don't need to store
// the event NID is already in the event map in memory then we can don't // it or do anything further with it.
// need to ask the database again. if _, ok := known[event.EventID()]; ok {
neededAuthEventNIDs := make([]string, 0, len(event.AuthEventIDs())) continue
for _, authEventID := range event.AuthEventIDs() {
if _, ok := cache[authEventID]; !ok {
neededAuthEventNIDs = append(neededAuthEventNIDs, authEventID)
}
} }
// If we need to fetch some event NIDs from the database then do that. // Otherwise, we need to store, and that means we need to know the
// We will also add those to the auth event map in memory, so that we // auth event NIDs. Let's see if we can find those.
// can skip future database hits for the same event IDs. authEventNIDs := make([]types.EventNID, 0, len(event.AuthEventIDs()))
if len(neededAuthEventNIDs) > 0 { for _, eventID := range event.AuthEventIDs() {
newAuthEventNIDs, err := r.DB.EventNIDs(ctx, neededAuthEventNIDs) knownEvent, ok := known[eventID]
if err != nil { if !ok {
return fmt.Errorf("r.DB.EventNIDs: %w", err) return fmt.Errorf("missing auth event %s for %s", eventID, event.EventID())
}
for authEventID, authEventNID := range newAuthEventNIDs {
cache[authEventID] = authEventNID
} }
authEventNIDs = append(authEventNIDs, knownEvent.EventNID)
} }
// Now collect the event NIDs for all of the auth events. // Let's take a note of the fact that we now know about this event.
authEventNIDsForEvent := make([]types.EventNID, 0, len(event.AuthEventIDs())) known[event.EventID()] = types.Event{}
for _, authEventID := range event.AuthEventIDs() { if err := auth.AddEvent(event); err != nil {
authEventNIDsForEvent = append(authEventNIDsForEvent, cache[authEventID]) return fmt.Errorf("auth.AddEvent: %w", err)
}
// If we haven't accumulated all of the auth events needed for the
// event then we shouldn't persist the event as something is wrong.
if len(authEventNIDsForEvent) != len(event.AuthEventIDs()) {
return fmt.Errorf("missing auth event NIDs for event %s", event.EventID())
} }
// Finally, store the event in the database. // Finally, store the event in the database.
if _, _, _, _, err := r.DB.StoreEvent(ctx, event, authEventNIDsForEvent, false); err != nil { eventNID, _, _, _, _, err := r.DB.StoreEvent(ctx, event, authEventNIDs, false)
if err != nil {
return fmt.Errorf("r.DB.StoreEvent: %w", err) return fmt.Errorf("r.DB.StoreEvent: %w", err)
} }
// Now we know about this event, too.
known[event.EventID()] = types.Event{
EventNID: eventNID,
Event: event,
}
} }
} }