mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
Try to optimise CheckServerAllowedToSeeEvent
by ensuring repeated state keys and events aren't requested
This commit is contained in:
parent
59cf8e936e
commit
a64d019559
2 changed files with 62 additions and 18 deletions
|
@ -222,11 +222,50 @@ func LoadStateEvents(
|
||||||
return LoadEvents(ctx, db, eventNIDs)
|
return LoadEvents(ctx, db, eventNIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CheckServerAllowedToSeeEvent(
|
type CheckServerAllowedToSeeEventContext struct {
|
||||||
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
ctx context.Context
|
||||||
|
db storage.Database
|
||||||
|
info types.RoomInfo
|
||||||
|
state state.StateResolution
|
||||||
|
stateKeys map[types.EventStateKeyNID]string
|
||||||
|
stateEvents map[types.EventNID]*gomatrixserverlib.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewCheckServerAllowedToSeeEventContext(ctx context.Context, db storage.Database, info types.RoomInfo) *CheckServerAllowedToSeeEventContext {
|
||||||
|
return &CheckServerAllowedToSeeEventContext{
|
||||||
|
ctx: ctx,
|
||||||
|
db: db,
|
||||||
|
info: info,
|
||||||
|
state: state.NewStateResolution(db, info),
|
||||||
|
stateKeys: make(map[types.EventStateKeyNID]string),
|
||||||
|
stateEvents: make(map[types.EventNID]*gomatrixserverlib.Event),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CheckServerAllowedToSeeEventContext) LoadStateEvents(
|
||||||
|
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
|
||||||
|
) ([]*gomatrixserverlib.Event, error) {
|
||||||
|
events := make([]*gomatrixserverlib.Event, 0, len(stateEntries))
|
||||||
|
eventNIDsToFetch := make([]types.EventNID, 0, len(stateEntries))
|
||||||
|
for i, e := range stateEntries {
|
||||||
|
if event, ok := c.stateEvents[e.EventNID]; ok {
|
||||||
|
events = append(events, event)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
eventNIDsToFetch = append(eventNIDsToFetch, stateEntries[i].EventNID)
|
||||||
|
}
|
||||||
|
fetchedEvents, err := LoadEvents(ctx, db, eventNIDsToFetch)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
events = append(events, fetchedEvents...)
|
||||||
|
return events, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *CheckServerAllowedToSeeEventContext) CheckServerAllowedToSeeEvent(
|
||||||
|
eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
||||||
) (bool, error) {
|
) (bool, error) {
|
||||||
roomState := state.NewStateResolution(db, info)
|
stateEntries, err := c.state.LoadStateAtEvent(c.ctx, eventID)
|
||||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return false, nil
|
return false, nil
|
||||||
|
@ -237,22 +276,21 @@ func CheckServerAllowedToSeeEvent(
|
||||||
// Extract all of the event state key NIDs from the room state.
|
// Extract all of the event state key NIDs from the room state.
|
||||||
var stateKeyNIDs []types.EventStateKeyNID
|
var stateKeyNIDs []types.EventStateKeyNID
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
|
if _, ok := c.stateKeys[entry.EventStateKeyNID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
|
stateKeyNIDs = append(stateKeyNIDs, entry.EventStateKeyNID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then request those state key NIDs from the database.
|
// Then request those state key NIDs from the database.
|
||||||
stateKeys, err := db.EventStateKeys(ctx, stateKeyNIDs)
|
stateKeys, err := c.db.EventStateKeys(c.ctx, stateKeyNIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("db.EventStateKeys: %w", err)
|
return false, fmt.Errorf("db.EventStateKeys: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the event state key doesn't match the given servername
|
// Add the results to the cache.
|
||||||
// then we'll filter it out. This does preserve state keys that
|
for stateKeyNID, stateKey := range stateKeys {
|
||||||
// are "" since these will contain history visibility etc.
|
c.stateKeys[stateKeyNID] = stateKey
|
||||||
for nid, key := range stateKeys {
|
|
||||||
if key != "" && !strings.HasSuffix(key, ":"+string(serverName)) {
|
|
||||||
delete(stateKeys, nid)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now filter through all of the state events for the room.
|
// Now filter through all of the state events for the room.
|
||||||
|
@ -260,8 +298,10 @@ func CheckServerAllowedToSeeEvent(
|
||||||
// keys then we'll add it to the list of filtered entries.
|
// keys then we'll add it to the list of filtered entries.
|
||||||
var filteredEntries []types.StateEntry
|
var filteredEntries []types.StateEntry
|
||||||
for _, entry := range stateEntries {
|
for _, entry := range stateEntries {
|
||||||
if _, ok := stateKeys[entry.EventStateKeyNID]; ok {
|
if key, ok := stateKeys[entry.EventStateKeyNID]; ok {
|
||||||
filteredEntries = append(filteredEntries, entry)
|
if key == "" || strings.HasSuffix(key, ":"+string(serverName)) {
|
||||||
|
filteredEntries = append(filteredEntries, entry)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -269,7 +309,7 @@ func CheckServerAllowedToSeeEvent(
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stateAtEvent, err := LoadStateEvents(ctx, db, filteredEntries)
|
stateAtEvent, err := c.LoadStateEvents(c.ctx, c.db, filteredEntries)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
@ -303,6 +343,8 @@ func ScanEventTree(
|
||||||
var checkedServerInRoom bool
|
var checkedServerInRoom bool
|
||||||
var isServerInRoom bool
|
var isServerInRoom bool
|
||||||
|
|
||||||
|
c := NewCheckServerAllowedToSeeEventContext(ctx, db, info)
|
||||||
|
|
||||||
// Loop through the event IDs to retrieve the requested events and go
|
// Loop through the event IDs to retrieve the requested events and go
|
||||||
// through the whole tree (up to the provided limit) using the events'
|
// through the whole tree (up to the provided limit) using the events'
|
||||||
// "prev_event" key.
|
// "prev_event" key.
|
||||||
|
@ -345,7 +387,7 @@ BFSLoop:
|
||||||
// hasn't been seen before.
|
// hasn't been seen before.
|
||||||
if !visited[pre] {
|
if !visited[pre] {
|
||||||
visited[pre] = true
|
visited[pre] = true
|
||||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
|
allowed, err = c.CheckServerAllowedToSeeEvent(pre, serverName, isServerInRoom)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||||
"Error checking if allowed to see event",
|
"Error checking if allowed to see event",
|
||||||
|
|
|
@ -376,8 +376,10 @@ func (r *Queryer) QueryServerAllowedToSeeEvent(
|
||||||
if info == nil {
|
if info == nil {
|
||||||
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
|
return fmt.Errorf("QueryServerAllowedToSeeEvent: no room info for room %s", roomID)
|
||||||
}
|
}
|
||||||
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
|
|
||||||
ctx, r.DB, *info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
|
c := helpers.NewCheckServerAllowedToSeeEventContext(ctx, r.DB, *info)
|
||||||
|
response.AllowedToSeeEvent, err = c.CheckServerAllowedToSeeEvent(
|
||||||
|
request.EventID, request.ServerName, inRoomRes.IsInRoom,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue