From b74a93f69ab69769c52833b6eca7dd1761ad3a9e Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 16 Jul 2021 10:00:30 +0100 Subject: [PATCH] Recursive fetch auth event NIDs --- roomserver/internal/query/query.go | 51 ++++++++++++++++++++++++------ 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index a2bb2848..088d64af 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -548,6 +548,8 @@ func (r *Queryer) QueryStateAndAuthChainIDs( return fmt.Errorf("r.DB.StateAtEventIDs: %w", err) } + // STATE EVENTS + eventNIDs := map[types.EventNID]struct{}{} for _, prevState := range prevStates { var entries []types.StateEntry @@ -559,30 +561,61 @@ func (r *Queryer) QueryStateAndAuthChainIDs( eventNIDs[entry.EventNID] = struct{}{} } } + var eventNIDsArray types.EventNIDs for nid := range eventNIDs { eventNIDsArray = append(eventNIDsArray, nid) } - authEventNIDsArray, err := r.DB.AuthEventNIDs(ctx, eventNIDsArray) - if err != nil { - return fmt.Errorf("r.DB.AuthEventNIDs: %w", err) - } - stateEventIDs, err := r.DB.EventIDs(ctx, eventNIDsArray) if err != nil { return fmt.Errorf("r.DB.EventIDs: %w", err) } + for _, eventID := range stateEventIDs { + response.StateEvents = append(response.StateEvents, eventID) + } + + // AUTH EVENTS + + covered := map[types.EventNID]bool{} + fetch := []types.EventNID{} + for _, eventNID := range eventNIDsArray { + covered[eventNID] = false + } + for { + fetch = fetch[:0] + for nid, id := range covered { + if !id { + fetch = append(fetch, nid) + } + } + if len(fetch) == 0 { + break + } + + var nids types.EventNIDs + nids, err = r.DB.AuthEventNIDs(ctx, fetch) + if err != nil { + return fmt.Errorf("r.DB.AuthEventNIDs: %w", err) + } + for _, nid := range nids { + if _, ok := covered[nid]; !ok { + covered[nid] = true + } + } + } + + authEventNIDsArray := make(types.EventNIDs, 0, len(covered)) + for nid := range covered { + authEventNIDsArray = append(authEventNIDsArray, nid) + } + authEventIDs, err := r.DB.EventIDs(ctx, authEventNIDsArray) if err != nil { return fmt.Errorf("r.DB.EventIDs: %w", err) } - for _, eventID := range stateEventIDs { - response.StateEvents = append(response.StateEvents, eventID) - } - for _, eventID := range authEventIDs { response.AuthChainEvents = append(response.AuthChainEvents, eventID) }