diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a389cc89..c588e1d9 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -222,11 +222,50 @@ func LoadStateEvents( return LoadEvents(ctx, db, eventNIDs) } -func CheckServerAllowedToSeeEvent( - ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, +type CheckServerAllowedToSeeEventContext struct { + ctx context.Context + db storage.Database + info types.RoomInfo + state state.StateResolution + stateKeys map[types.EventStateKeyNID]string + stateEvents map[types.EventNID]*gomatrixserverlib.Event +} + +func NewCheckServerAllowedToSeeEventContext(ctx context.Context, db storage.Database, info types.RoomInfo) *CheckServerAllowedToSeeEventContext { + return &CheckServerAllowedToSeeEventContext{ + ctx: ctx, + db: db, + info: info, + state: state.NewStateResolution(db, info), + stateKeys: make(map[types.EventStateKeyNID]string), + stateEvents: make(map[types.EventNID]*gomatrixserverlib.Event), + } +} + +func (c *CheckServerAllowedToSeeEventContext) LoadStateEvents( + ctx context.Context, db storage.Database, stateEntries []types.StateEntry, +) ([]*gomatrixserverlib.Event, error) { + events := make([]*gomatrixserverlib.Event, 0, len(stateEntries)) + eventNIDsToFetch := make([]types.EventNID, 0, len(stateEntries)) + for i, e := range stateEntries { + if event, ok := c.stateEvents[e.EventNID]; ok { + events = append(events, event) + continue + } + eventNIDsToFetch = append(eventNIDsToFetch, stateEntries[i].EventNID) + } + fetchedEvents, err := LoadEvents(ctx, db, eventNIDsToFetch) + if err != nil { + return nil, err + } + events = append(events, fetchedEvents...) + return events, nil +} + +func (c *CheckServerAllowedToSeeEventContext) CheckServerAllowedToSeeEvent( + eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(db, info) - stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) + stateEntries, err := c.state.LoadStateAtEvent(c.ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { return false, nil @@ -237,22 +276,21 @@ func CheckServerAllowedToSeeEvent( // Extract all of the event state key NIDs from the room state. var stateKeyNIDs []types.EventStateKeyNID for _, entry := range stateEntries { + if _, ok := c.stateKeys[entry.EventStateKeyNID]; ok { + continue + } stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) } // Then request those state key NIDs from the database. - stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) + stateKeys, err := c.db.EventStateKeys(c.ctx, stateKeyNIDs) if err != nil { return false, fmt.Errorf("db.EventStateKeys: %w", err) } - // If the event state key doesn't match the given servername - // then we'll filter it out. This does preserve state keys that - // are "" since these will contain history visibility etc. - for nid, key := range stateKeys { - if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) { - delete(stateKeys, nid) - } + // Add the results to the cache. + for stateKeyNID, stateKey := range stateKeys { + c.stateKeys[stateKeyNID] = stateKey } // Now filter through all of the state events for the room. @@ -260,8 +298,10 @@ func CheckServerAllowedToSeeEvent( // keys then we'll add it to the list of filtered entries. var filteredEntries []types.StateEntry for _, entry := range stateEntries { - if _, ok := stateKeys[entry.EventStateKeyNID]; ok { - filteredEntries = append(filteredEntries, entry) + if key, ok := stateKeys[entry.EventStateKeyNID]; ok { + if key == "" || strings.HasSuffix(key, ":"+string(serverName)) { + filteredEntries = append(filteredEntries, entry) + } } } @@ -269,7 +309,7 @@ func CheckServerAllowedToSeeEvent( return false, nil } - stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) + stateAtEvent, err := c.LoadStateEvents(c.ctx, c.db, filteredEntries) if err != nil { return false, err } @@ -303,6 +343,8 @@ func ScanEventTree( var checkedServerInRoom bool var isServerInRoom bool + c := NewCheckServerAllowedToSeeEventContext(ctx, db, info) + // Loop through the event IDs to retrieve the requested events and go // through the whole tree (up to the provided limit) using the events' // "prev_event" key. @@ -345,7 +387,7 @@ BFSLoop: // hasn't been seen before. if !visited[pre] { visited[pre] = true - allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) + allowed, err = c.CheckServerAllowedToSeeEvent(pre, serverName, isServerInRoom) if err != nil { util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( "Error checking if allowed to see event", diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index b80f08ab..e08d5b5a 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -376,8 +376,10 @@ func (r *Queryer) QueryServerAllowedToSeeEvent( if info == nil { return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) } - response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent( - ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, + + c := helpers.NewCheckServerAllowedToSeeEventContext(ctx, r.DB, *info) + response.AllowedToSeeEvent, err = c.CheckServerAllowedToSeeEvent( + request.EventID, request.ServerName, inRoomRes.IsInRoom, ) return }