diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 1f4215e7..e9ae766a 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -83,7 +83,7 @@ func CheckForSoftFail( // Check if the event is allowed. if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil { // return true, nil - return true, err + return true, fmt.Errorf("gomatrixserverlib.Allowed: %w", err) } return false, nil } diff --git a/roomserver/internal/helpers/eventcache.go b/roomserver/internal/helpers/eventcache.go new file mode 100644 index 00000000..54f69045 --- /dev/null +++ b/roomserver/internal/helpers/eventcache.go @@ -0,0 +1,63 @@ +package helpers + +import ( + "context" + "fmt" + "sync" + + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/roomserver/types" +) + +type CachedDB struct { + mutex sync.RWMutex + eventsByID map[string]*types.Event + eventsByNID map[types.EventNID]*types.Event + storage.Database +} + +func NewCachedDB(db storage.Database) *CachedDB { + return &CachedDB{ + Database: db, + eventsByID: make(map[string]*types.Event), + eventsByNID: make(map[types.EventNID]*types.Event), + } +} + +func (c *CachedDB) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) { + fmt.Println("Want", eventNIDs) + events := make([]types.Event, len(eventNIDs)) + retrieve := make([]types.EventNID, 0, len(eventNIDs)) + c.mutex.RLock() + for i, eventNID := range eventNIDs { + if cached, ok := c.eventsByNID[eventNID]; ok { + events[i] = *cached + fmt.Println(i, "Existing", cached, cached.EventID(), cached.EventNID, cached.Type(), *cached.StateKey()) + } else { + retrieve = append(retrieve, eventNID) + } + } + c.mutex.RUnlock() + var retrieved []types.Event + var err error + if len(retrieve) > 0 { + retrieved, err = c.Database.Events(ctx, retrieve) + if err != nil { + return nil, err + } + } + c.mutex.Lock() + defer c.mutex.Unlock() + for i, event := range retrieved { + c.eventsByID[event.EventID()] = &retrieved[i] + c.eventsByNID[event.EventNID] = &retrieved[i] + } + for i, eventNID := range eventNIDs { + if cached, ok := c.eventsByNID[eventNID]; ok { + fmt.Println(i, "Found", cached, cached.EventID(), cached.EventNID, cached.Type(), *cached.StateKey(), string(cached.Content())) + events[i] = *cached + } + } + fmt.Println("Returning", events) + return events, nil +} diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 6bc43c9c..14be6192 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/internal/hooks" "github.com/matrix-org/dendrite/roomserver/acls" "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" "github.com/prometheus/client_golang/prometheus" @@ -53,11 +54,16 @@ type inputWorker struct { r *Inputer running atomic.Bool input *fifoQueue + db *helpers.CachedDB } // Guarded by a CAS on w.running func (w *inputWorker) start() { - defer w.running.Store(false) + defer func() { + w.db = nil + w.running.Store(false) + }() + w.db = helpers.NewCachedDB(w.r.DB) for { select { case <-w.input.wait(): @@ -69,7 +75,7 @@ func (w *inputWorker) start() { "room_id": task.event.Event.RoomID(), }).Dec() hooks.Run(hooks.KindNewEventReceived, task.event.Event) - _, task.err = w.r.processRoomEvent(task.ctx, task.event) + _, task.err = w.r.processRoomEvent(task.ctx, task.event, w.db) if task.err == nil { hooks.Run(hooks.KindNewEventPersisted, task.event.Event) } else { diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2a558c48..ff49e638 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -26,6 +26,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/state" + "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -62,6 +63,7 @@ var processRoomEventDuration = prometheus.NewHistogramVec( func (r *Inputer) processRoomEvent( ctx context.Context, input *api.InputRoomEvent, + db storage.Database, ) (eventID string, err error) { // Measure how long it takes to process this event. started := time.Now() @@ -79,7 +81,7 @@ func (r *Inputer) processRoomEvent( // if we have already got this event then do not process it again, if the input kind is an outlier. // Outliers contain no extra information which may warrant a re-processing. if input.Kind == api.KindOutlier { - evs, err2 := r.DB.EventsFromIDs(ctx, []string{event.EventID()}) + evs, err2 := db.EventsFromIDs(ctx, []string{event.EventID()}) if err2 == nil && len(evs) == 1 { // check hash matches if we're on early room versions where the event ID was a random string idFormat, err2 := headered.RoomVersion.EventIDFormat() @@ -101,7 +103,7 @@ func (r *Inputer) processRoomEvent( // Check that the event passes authentication checks and work out // the numeric IDs for the auth events. isRejected := false - authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, r.DB, headered, input.AuthEventIDs) + authEventNIDs, rejectionErr := helpers.CheckAuthEvents(ctx, db, headered, input.AuthEventIDs) if rejectionErr != nil { logrus.WithError(rejectionErr).WithField("event_id", event.EventID()).WithField("auth_event_ids", input.AuthEventIDs).Error("helpers.CheckAuthEvents failed for event, rejecting event") isRejected = true @@ -111,7 +113,7 @@ func (r *Inputer) processRoomEvent( if input.Kind == api.KindNew { // Check that the event passes authentication checks based on the // current room state. - softfail, err = helpers.CheckForSoftFail(ctx, r.DB, headered, input.StateEventIDs) + softfail, err = helpers.CheckForSoftFail(ctx, db, headered, input.StateEventIDs) if err != nil { logrus.WithFields(logrus.Fields{ "event_id": event.EventID(),