From 5223d02709587fbd0fe226b9efd66b3f43480b11 Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 25 Jan 2021 16:29:00 +0000 Subject: [PATCH] Still use workers for room concurrency --- roomserver/internal/input/input.go | 91 +++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 7 deletions(-) diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index b16f1da5..ccca3ef3 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -19,6 +19,7 @@ import ( "context" "encoding/json" "sync" + "time" "github.com/Shopify/sarama" "github.com/matrix-org/dendrite/internal/hooks" @@ -27,6 +28,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/gomatrixserverlib" log "github.com/sirupsen/logrus" + "go.uber.org/atomic" ) type Inputer struct { @@ -37,6 +39,38 @@ type Inputer struct { OutputRoomEventTopic string latestEventsMutexes sync.Map // room ID -> sync.Mutex + workers sync.Map // room ID -> *inputWorker +} + +type inputTask struct { + ctx context.Context + event *api.InputRoomEvent + wg *sync.WaitGroup + err error // written back by worker, only safe to read when all tasks are done +} + +type inputWorker struct { + r *Inputer + running atomic.Bool + input chan *inputTask +} + +// Guarded by a CAS on w.running +func (w *inputWorker) start() { + defer w.running.Store(false) + for { + select { + case task := <-w.input: + hooks.Run(hooks.KindNewEventReceived, task.event.Event) + _, task.err = w.r.processRoomEvent(task.ctx, task.event) + if task.err == nil { + hooks.Run(hooks.KindNewEventPersisted, task.event.Event) + } + task.wg.Done() + case <-time.After(time.Second * 5): + return + } + } } // WriteOutputEvents implements OutputRoomEventWriter @@ -87,14 +121,57 @@ func (r *Inputer) InputRoomEvents( request *api.InputRoomEventsRequest, response *api.InputRoomEventsResponse, ) { - for i := range request.InputRoomEvents { - hooks.Run(hooks.KindNewEventReceived, &request.InputRoomEvents[i]) - if _, err := r.processRoomEvent(context.Background(), &request.InputRoomEvents[i]); err == nil { - hooks.Run(hooks.KindNewEventPersisted, &request.InputRoomEvents[i]) - } else { - response.ErrMsg = err.Error() - _, rejected := err.(*gomatrixserverlib.NotAllowed) + // Create a wait group. Each task that we dispatch will call Done on + // this wait group so that we know when all of our events have been + // processed. + wg := &sync.WaitGroup{} + wg.Add(len(request.InputRoomEvents)) + tasks := make([]*inputTask, len(request.InputRoomEvents)) + + for i, e := range request.InputRoomEvents { + // Work out if we are running per-room workers or if we're just doing + // it on a global basis (e.g. SQLite). + roomID := "global" + if r.DB.SupportsConcurrentRoomInputs() { + roomID = e.Event.RoomID() + } + + // Look up the worker, or create it if it doesn't exist. This channel + // is buffered to reduce the chance that we'll be blocked by another + // room - the channel will be quite small as it's just pointer types. + w, _ := r.workers.LoadOrStore(roomID, &inputWorker{ + r: r, + input: make(chan *inputTask, 32), + }) + worker := w.(*inputWorker) + + // Create a task. This contains the input event and a reference to + // the wait group, so that the worker can notify us when this specific + // task has been finished. + tasks[i] = &inputTask{ + ctx: context.Background(), + event: &request.InputRoomEvents[i], + wg: wg, + } + + // Send the task to the worker. + if worker.running.CAS(false, true) { + go worker.start() + } + worker.input <- tasks[i] + } + + // Wait for all of the workers to return results about our tasks. + wg.Wait() + + // If any of the tasks returned an error, we should probably report + // that back to the caller. + for _, task := range tasks { + if task.err != nil { + response.ErrMsg = task.err.Error() + _, rejected := task.err.(*gomatrixserverlib.NotAllowed) response.NotAllowed = rejected + return } } }