diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index f4216805..147103cf 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -130,7 +130,10 @@ func (r *Inputer) processRoomEvent( return fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) } } - if len(missingRes.MissingAuthEventIDs) > 0 || len(missingRes.MissingPrevEventIDs) > 0 { + missingAuth := len(missingRes.MissingAuthEventIDs) > 0 + missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 + + if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ RoomID: event.RoomID(), ExcludeSelf: true, @@ -138,9 +141,26 @@ func (r *Inputer) processRoomEvent( if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil { return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err) } - } - if input.Origin != "" { - serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + // Sort all of the servers into a map so that we can randomise + // their order. Then make sure that the input origin and the + // event origin are first on the list. + servers := map[gomatrixserverlib.ServerName]struct{}{} + for _, server := range serverRes.ServerNames { + servers[server] = struct{}{} + } + serverRes.ServerNames = serverRes.ServerNames[:0] + if input.Origin != "" { + serverRes.ServerNames = append(serverRes.ServerNames, input.Origin) + delete(servers, input.Origin) + } + if origin := event.Origin(); origin != input.Origin { + serverRes.ServerNames = append(serverRes.ServerNames, origin) + delete(servers, origin) + } + for server := range servers { + serverRes.ServerNames = append(serverRes.ServerNames, server) + delete(servers, server) + } } // First of all, check that the auth events of the event are known. @@ -149,7 +169,7 @@ func (r *Inputer) processRoomEvent( authEvents := gomatrixserverlib.NewAuthEvents(nil) knownEvents := map[string]*types.Event{} if err = r.fetchAuthEvents(ctx, logger, headered, &authEvents, knownEvents, serverRes.ServerNames); err != nil { - return fmt.Errorf("r.checkForMissingAuthEvents: %w", err) + return fmt.Errorf("r.fetchAuthEvents: %w", err) } // Check if the event is allowed by its auth events. If it isn't then @@ -190,7 +210,6 @@ func (r *Inputer) processRoomEvent( // typical federated room join) then we won't bother trying to fetch prev events // because we may not be allowed to see them and we have no choice but to trust // the state event IDs provided to us in the join instead. - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingPrev && input.Kind == api.KindNew { // Don't do this for KindOld events, otherwise old events that we fetch // to satisfy missing prev events/state will end up recursively calling @@ -204,13 +223,10 @@ func (r *Inputer) processRoomEvent( federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), - servers: map[gomatrixserverlib.ServerName]struct{}{}, + servers: serverRes.ServerNames, hadEvents: map[string]bool{}, haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, } - for _, serverName := range serverRes.ServerNames { - missingState.servers[serverName] = struct{}{} - } if err = missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { isRejected = true rejectionErr = fmt.Errorf("missingState.processEventWithMissingState: %w", err) @@ -399,12 +415,11 @@ func (r *Inputer) fetchAuthEvents( continue } - // Check the signatures of the event. - // TODO: It really makes sense for the federation API to be doing this, - // because then it can attempt another server if one serves up an event - // with an invalid signature. For now this will do. + // Check the signatures of the event. If this fails then we'll simply + // skip it, because gomatrixserverlib.Allowed() will notice a problem + // if a critical event is missing anyway. if err := authEvent.VerifyEventSignatures(ctx, r.FSAPI.KeyRing()); err != nil { - return fmt.Errorf("event.VerifyEventSignatures: %w", err) + continue } // In order to store the new auth event, we need to know its auth chain @@ -457,7 +472,7 @@ func (r *Inputer) calculateAndSetState( var err error roomState := state.NewStateResolution(r.DB, roomInfo) - if input.HasState && !isRejected { + if input.HasState { // Check here if we think we're in the room already. stateAtEvent.Overwrite = true var joinEventNIDs []types.EventNID diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index aa2b94f8..02ff0f8d 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -25,7 +25,7 @@ type missingStateReq struct { keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom - servers map[gomatrixserverlib.ServerName]struct{} + servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex haveEvents map[string]*gomatrixserverlib.HeaderedEvent @@ -417,7 +417,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve } var missingResp *gomatrixserverlib.RespMissingEvents - for server := range t.servers { + for _, server := range t.servers { var m gomatrixserverlib.RespMissingEvents if m, err = t.federation.LookupMissingEvents(ctx, server, e.RoomID(), gomatrixserverlib.MissingEvents{ Limit: 20, @@ -700,7 +700,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs } var event *gomatrixserverlib.Event found := false - for serverName := range t.servers { + for _, serverName := range t.servers { reqctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() txn, err := t.federation.GetEvent(reqctx, serverName, missingEventID) diff --git a/roomserver/internal/perform/perform_join.go b/roomserver/internal/perform/perform_join.go index 2b0bccda..dfa21bcb 100644 --- a/roomserver/internal/perform/perform_join.go +++ b/roomserver/internal/perform/perform_join.go @@ -51,7 +51,7 @@ func (r *Joiner) PerformJoin( req *rsAPI.PerformJoinRequest, res *rsAPI.PerformJoinResponse, ) { - roomID, joinedVia, err := r.performJoin(ctx, req) + roomID, joinedVia, err := r.performJoin(context.Background(), req) if err != nil { logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomIDOrAlias, diff --git a/roomserver/internal/perform/perform_leave.go b/roomserver/internal/perform/perform_leave.go index b1991649..3c46e657 100644 --- a/roomserver/internal/perform/perform_leave.go +++ b/roomserver/internal/perform/perform_leave.go @@ -52,7 +52,7 @@ func (r *Leaver) PerformLeave( return nil, fmt.Errorf("user %q does not belong to this homeserver", req.UserID) } if strings.HasPrefix(req.RoomID, "!") { - output, err := r.performLeaveRoomByID(ctx, req, res) + output, err := r.performLeaveRoomByID(context.Background(), req, res) if err != nil { logrus.WithContext(ctx).WithFields(logrus.Fields{ "room_id": req.RoomID,