diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go
index a2bb2848..088d64af 100644
--- a/roomserver/internal/query/query.go
+++ b/roomserver/internal/query/query.go
@@ -548,6 +548,8 @@ func (r *Queryer) QueryStateAndAuthChainIDs(
 		return fmt.Errorf("r.DB.StateAtEventIDs: %w", err)
 	}
 
+	// STATE EVENTS
+
 	eventNIDs := map[types.EventNID]struct{}{}
 	for _, prevState := range prevStates {
 		var entries []types.StateEntry
@@ -559,30 +561,61 @@ func (r *Queryer) QueryStateAndAuthChainIDs(
 			eventNIDs[entry.EventNID] = struct{}{}
 		}
 	}
+
 	var eventNIDsArray types.EventNIDs
 	for nid := range eventNIDs {
 		eventNIDsArray = append(eventNIDsArray, nid)
 	}
 
-	authEventNIDsArray, err := r.DB.AuthEventNIDs(ctx, eventNIDsArray)
-	if err != nil {
-		return fmt.Errorf("r.DB.AuthEventNIDs: %w", err)
-	}
-
 	stateEventIDs, err := r.DB.EventIDs(ctx, eventNIDsArray)
 	if err != nil {
 		return fmt.Errorf("r.DB.EventIDs: %w", err)
 	}
 
+	for _, eventID := range stateEventIDs {
+		response.StateEvents = append(response.StateEvents, eventID)
+	}
+
+	// AUTH EVENTS
+
+	covered := map[types.EventNID]bool{}
+	fetch := []types.EventNID{}
+	for _, eventNID := range eventNIDsArray {
+		covered[eventNID] = false
+	}
+	for {
+		fetch = fetch[:0]
+		for nid, id := range covered {
+			if !id {
+				fetch = append(fetch, nid)
+			}
+		}
+		if len(fetch) == 0 {
+			break
+		}
+
+		var nids types.EventNIDs
+		nids, err = r.DB.AuthEventNIDs(ctx, fetch)
+		if err != nil {
+			return fmt.Errorf("r.DB.AuthEventNIDs: %w", err)
+		}
+		for _, nid := range nids {
+			if _, ok := covered[nid]; !ok {
+				covered[nid] = true
+			}
+		}
+	}
+
+	authEventNIDsArray := make(types.EventNIDs, 0, len(covered))
+	for nid := range covered {
+		authEventNIDsArray = append(authEventNIDsArray, nid)
+	}
+
 	authEventIDs, err := r.DB.EventIDs(ctx, authEventNIDsArray)
 	if err != nil {
 		return fmt.Errorf("r.DB.EventIDs: %w", err)
 	}
 
-	for _, eventID := range stateEventIDs {
-		response.StateEvents = append(response.StateEvents, eventID)
-	}
-
 	for _, eventID := range authEventIDs {
 		response.AuthChainEvents = append(response.AuthChainEvents, eventID)
 	}