Try to optimise CheckServerAllowedToSeeEvent by ensuring repeated state keys and events aren't requested

This commit is contained in:
Neil Alexander 2021-11-08 14:41:09 +00:00
parent 59cf8e936e
commit a64d019559
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
2 changed files with 62 additions and 18 deletions

View file

@ -222,11 +222,50 @@ func LoadStateEvents(
return LoadEvents(ctx, db, eventNIDs) return LoadEvents(ctx, db, eventNIDs)
} }
func CheckServerAllowedToSeeEvent( type CheckServerAllowedToSeeEventContext struct {
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ctx context.Context
db storage.Database
info types.RoomInfo
state state.StateResolution
stateKeys map[types.EventStateKeyNID]string
stateEvents map[types.EventNID]*gomatrixserverlib.Event
}
func NewCheckServerAllowedToSeeEventContext(ctx context.Context, db storage.Database, info types.RoomInfo) *CheckServerAllowedToSeeEventContext {
return &CheckServerAllowedToSeeEventContext{
ctx: ctx,
db: db,
info: info,
state: state.NewStateResolution(db, info),
stateKeys: make(map[types.EventStateKeyNID]string),
stateEvents: make(map[types.EventNID]*gomatrixserverlib.Event),
}
}
func (c *CheckServerAllowedToSeeEventContext) LoadStateEvents(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) {
events := make([]*gomatrixserverlib.Event, 0, len(stateEntries))
eventNIDsToFetch := make([]types.EventNID, 0, len(stateEntries))
for i, e := range stateEntries {
if event, ok := c.stateEvents[e.EventNID]; ok {
events = append(events, event)
continue
}
eventNIDsToFetch = append(eventNIDsToFetch, stateEntries[i].EventNID)
}
fetchedEvents, err := LoadEvents(ctx, db, eventNIDsToFetch)
if err != nil {
return nil, err
}
events = append(events, fetchedEvents...)
return events, nil
}
func (c *CheckServerAllowedToSeeEventContext) CheckServerAllowedToSeeEvent(
eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) { ) (bool, error) {
roomState := state.NewStateResolution(db, info) stateEntries, err := c.state.LoadStateAtEvent(c.ctx, eventID)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return false, nil return false, nil
@ -237,22 +276,21 @@ func CheckServerAllowedToSeeEvent(
// Extract all of the event state key NIDs from the room state. // Extract all of the event state key NIDs from the room state.
var stateKeyNIDs []types.EventStateKeyNID var stateKeyNIDs []types.EventStateKeyNID
for _, entry := range stateEntries { for _, entry := range stateEntries {
if _, ok := c.stateKeys[entry.EventStateKeyNID]; ok {
continue
}
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID) stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
} }
// Then request those state key NIDs from the database. // Then request those state key NIDs from the database.
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs) stateKeys, err := c.db.EventStateKeys(c.ctx, stateKeyNIDs)
if err != nil { if err != nil {
return false, fmt.Errorf("db.EventStateKeys: %w", err) return false, fmt.Errorf("db.EventStateKeys: %w", err)
} }
// If the event state key doesn't match the given servername // Add the results to the cache.
// then we'll filter it out. This does preserve state keys that for stateKeyNID, stateKey := range stateKeys {
// are "" since these will contain history visibility etc. c.stateKeys[stateKeyNID] = stateKey
for nid, key := range stateKeys {
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
delete(stateKeys, nid)
}
} }
// Now filter through all of the state events for the room. // Now filter through all of the state events for the room.
@ -260,8 +298,10 @@ func CheckServerAllowedToSeeEvent(
// keys then we'll add it to the list of filtered entries. // keys then we'll add it to the list of filtered entries.
var filteredEntries []types.StateEntry var filteredEntries []types.StateEntry
for _, entry := range stateEntries { for _, entry := range stateEntries {
if _, ok := stateKeys[entry.EventStateKeyNID]; ok { if key, ok := stateKeys[entry.EventStateKeyNID]; ok {
filteredEntries = append(filteredEntries, entry) if key == "" || strings.HasSuffix(key, ":"+string(serverName)) {
filteredEntries = append(filteredEntries, entry)
}
} }
} }
@ -269,7 +309,7 @@ func CheckServerAllowedToSeeEvent(
return false, nil return false, nil
} }
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries) stateAtEvent, err := c.LoadStateEvents(c.ctx, c.db, filteredEntries)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -303,6 +343,8 @@ func ScanEventTree(
var checkedServerInRoom bool var checkedServerInRoom bool
var isServerInRoom bool var isServerInRoom bool
c := NewCheckServerAllowedToSeeEventContext(ctx, db, info)
// Loop through the event IDs to retrieve the requested events and go // Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events' // through the whole tree (up to the provided limit) using the events'
// "prev_event" key. // "prev_event" key.
@ -345,7 +387,7 @@ BFSLoop:
// hasn't been seen before. // hasn't been seen before.
if !visited[pre] { if !visited[pre] {
visited[pre] = true visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom) allowed, err = c.CheckServerAllowedToSeeEvent(pre, serverName, isServerInRoom)
if err != nil { if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error( util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event", "Error checking if allowed to see event",

View file

@ -376,8 +376,10 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
if info == nil { if info == nil {
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID) return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
} }
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom, c := helpers.NewCheckServerAllowedToSeeEventContext(ctx, r.DB, *info)
response.AllowedToSeeEvent, err = c.CheckServerAllowedToSeeEvent(
request.EventID, request.ServerName, inRoomRes.IsInRoom,
) )
return return
} }