Reduce memory usage in federation /send endpoint (#1890)

* More aggressive event caching

* Deduplicate /state results

* Deduplicate more

* Ensure we use the correct list of events when excluding repeated state

* Fixes

* Ensure we track all events we already knew about properly
This commit is contained in:
Neil Alexander 2021-06-30 10:01:56 +01:00 committed by GitHub
parent c849e74dfc
commit 3afb161352
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 41 additions and 30 deletions

View file

@ -106,8 +106,8 @@ func Send(
eduAPI: eduAPI, eduAPI: eduAPI,
keys: keys, keys: keys,
federation: federation, federation: federation,
hadEvents: make(map[string]bool),
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
newEvents: make(map[string]bool),
keyAPI: keyAPI, keyAPI: keyAPI,
roomsMu: mu, roomsMu: mu,
} }
@ -167,13 +167,12 @@ type txnReq struct {
servers []gomatrixserverlib.ServerName servers []gomatrixserverlib.ServerName
serversMutex sync.RWMutex serversMutex sync.RWMutex
roomsMu *internal.MutexByRoom roomsMu *internal.MutexByRoom
// a list of events from the auth and prev events which we already had
hadEvents map[string]bool
// local cache of events for auth checks, etc - this may include events // local cache of events for auth checks, etc - this may include events
// which the roomserver is unaware of. // which the roomserver is unaware of.
haveEvents map[string]*gomatrixserverlib.HeaderedEvent haveEvents map[string]*gomatrixserverlib.HeaderedEvent
// new events which the roomserver does not know about work string // metrics
newEvents map[string]bool
newEventsMutex sync.RWMutex
work string // metrics
} }
// A subset of FederationClient functionality that txn requires. Useful for testing. // A subset of FederationClient functionality that txn requires. Useful for testing.
@ -340,19 +339,6 @@ func (e missingPrevEventsError) Error() string {
return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err) return fmt.Sprintf("unable to get prev_events for event %q: %s", e.eventID, e.err)
} }
func (t *txnReq) haveEventIDs() map[string]bool {
t.newEventsMutex.RLock()
defer t.newEventsMutex.RUnlock()
result := make(map[string]bool, len(t.haveEvents))
for eventID := range t.haveEvents {
if t.newEvents[eventID] {
continue
}
result[eventID] = true
}
return result
}
func (t *txnReq) processEDUs(ctx context.Context) { func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs { for _, e := range t.EDUs {
eduCountTotal.Inc() eduCountTotal.Inc()
@ -527,6 +513,15 @@ func (t *txnReq) processEvent(ctx context.Context, e *gomatrixserverlib.Event) e
return roomNotFoundError{e.RoomID()} return roomNotFoundError{e.RoomID()}
} }
// Prepare a map of all the events we already had before this point, so
// that we don't send them to the roomserver again.
for _, eventID := range append(e.AuthEventIDs(), e.PrevEventIDs()...) {
t.hadEvents[eventID] = true
}
for _, eventID := range append(stateResp.MissingAuthEventIDs, stateResp.MissingPrevEventIDs...) {
t.hadEvents[eventID] = false
}
if len(stateResp.MissingAuthEventIDs) > 0 { if len(stateResp.MissingAuthEventIDs) > 0 {
t.work = MetricsWorkMissingAuthEvents t.work = MetricsWorkMissingAuthEvents
logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs)) logger.Infof("Event refers to %d unknown auth_events", len(stateResp.MissingAuthEventIDs))
@ -596,6 +591,8 @@ withNextEvent:
); err != nil { ); err != nil {
return fmt.Errorf("api.SendEvents: %w", err) return fmt.Errorf("api.SendEvents: %w", err)
} }
t.hadEvents[ev.EventID()] = true // if the roomserver didn't know about the event before, it does now
t.cacheAndReturn(ev.Headered(stateResp.RoomVersion))
delete(missingAuthEvents, missingAuthEventID) delete(missingAuthEvents, missingAuthEventID)
continue withNextEvent continue withNextEvent
} }
@ -739,7 +736,7 @@ func (t *txnReq) processEventWithMissingState(
api.KindOld, api.KindOld,
resolvedState, resolvedState,
backwardsExtremity.Headered(roomVersion), backwardsExtremity.Headered(roomVersion),
t.haveEventIDs(), t.hadEvents,
) )
if err != nil { if err != nil {
return fmt.Errorf("api.SendEventWithState: %w", err) return fmt.Errorf("api.SendEventWithState: %w", err)
@ -791,7 +788,7 @@ func (t *txnReq) lookupStateAfterEvent(ctx context.Context, roomVersion gomatrix
default: default:
return nil, false, fmt.Errorf("t.lookupEvent: %w", err) return nil, false, fmt.Errorf("t.lookupEvent: %w", err)
} }
t.cacheAndReturn(h) h = t.cacheAndReturn(h)
if h.StateKey() != nil { if h.StateKey() != nil {
addedToState := false addedToState := false
for i := range respState.StateEvents { for i := range respState.StateEvents {
@ -833,6 +830,7 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
// processEvent request, which is better for memory. // processEvent request, which is better for memory.
stateEvents[i] = t.cacheAndReturn(ev) stateEvents[i] = t.cacheAndReturn(ev)
t.hadEvents[ev.EventID()] = true
} }
// we should never access res.StateEvents again so we delete it here to make GC faster // we should never access res.StateEvents again so we delete it here to make GC faster
res.StateEvents = nil res.StateEvents = nil
@ -863,8 +861,9 @@ func (t *txnReq) lookupStateAfterEventLocally(ctx context.Context, roomID, event
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
return nil return nil
} }
for i := range queryRes.Events { for i, ev := range queryRes.Events {
authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap())
t.hadEvents[ev.EventID()] = true
} }
queryRes.Events = nil queryRes.Events = nil
} }
@ -939,8 +938,9 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
return nil, err return nil, err
} }
latestEvents := make([]string, len(res.LatestEvents)) latestEvents := make([]string, len(res.LatestEvents))
for i := range res.LatestEvents { for i, ev := range res.LatestEvents {
latestEvents[i] = res.LatestEvents[i].EventID latestEvents[i] = res.LatestEvents[i].EventID
t.hadEvents[ev.EventID] = true
} }
var missingResp *gomatrixserverlib.RespMissingEvents var missingResp *gomatrixserverlib.RespMissingEvents
@ -985,6 +985,12 @@ func (t *txnReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Even
// For now, we do not allow Case B, so reject the event. // For now, we do not allow Case B, so reject the event.
logger.Infof("get_missing_events returned %d events", len(missingResp.Events)) logger.Infof("get_missing_events returned %d events", len(missingResp.Events))
// Make sure events from the missingResp are using the cache - missing events
// will be added and duplicates will be removed.
for i, ev := range missingResp.Events {
missingResp.Events[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
}
// topologically sort and sanity check that we are making forward progress // topologically sort and sanity check that we are making forward progress
newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents) newEvents = gomatrixserverlib.ReverseTopologicalOrdering(missingResp.Events, gomatrixserverlib.TopologicalOrderByPrevEvents)
shouldHaveSomeEventIDs := e.PrevEventIDs() shouldHaveSomeEventIDs := e.PrevEventIDs()
@ -1023,6 +1029,14 @@ func (t *txnReq) lookupMissingStateViaState(ctx context.Context, roomID, eventID
if err := state.Check(ctx, t.keys, nil); err != nil { if err := state.Check(ctx, t.keys, nil); err != nil {
return nil, err return nil, err
} }
// Cache the results of this state lookup and deduplicate anything we already
// have in the cache, freeing up memory.
for i, ev := range state.AuthEvents {
state.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
}
for i, ev := range state.StateEvents {
state.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()
}
return &state, nil return &state, nil
} }
@ -1055,9 +1069,10 @@ func (t *txnReq) lookupMissingStateViaStateIDs(ctx context.Context, roomID, even
if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { if err = t.rsAPI.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil {
return nil, err return nil, err
} }
for i := range queryRes.Events { for i, ev := range queryRes.Events {
queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i])
t.hadEvents[ev.EventID()] = true
evID := queryRes.Events[i].EventID() evID := queryRes.Events[i].EventID()
t.cacheAndReturn(queryRes.Events[i])
if missing[evID] { if missing[evID] {
delete(missing, evID) delete(missing, evID)
} }
@ -1221,9 +1236,5 @@ func (t *txnReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.
util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID()) util.GetLogger(ctx).WithError(err).Warnf("Transaction: Couldn't validate signature of event %q", event.EventID())
return nil, verifySigError{event.EventID(), err} return nil, verifySigError{event.EventID(), err}
} }
h := event.Headered(roomVersion) return t.cacheAndReturn(event.Headered(roomVersion)), nil
t.newEventsMutex.Lock()
t.newEvents[h.EventID()] = true
t.newEventsMutex.Unlock()
return h, nil
} }

View file

@ -370,7 +370,7 @@ func mustCreateTransaction(rsAPI api.RoomserverInternalAPI, fedClient txnFederat
keys: &test.NopJSONVerifier{}, keys: &test.NopJSONVerifier{},
federation: fedClient, federation: fedClient,
haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent), haveEvents: make(map[string]*gomatrixserverlib.HeaderedEvent),
newEvents: make(map[string]bool), hadEvents: make(map[string]bool),
roomsMu: internal.NewMutexByRoom(), roomsMu: internal.NewMutexByRoom(),
} }
t.PDUs = pdus t.PDUs = pdus