diff --git a/roomserver/internal/input/input.go b/roomserver/internal/input/input.go index 4d6da229..61f46cb2 100644 --- a/roomserver/internal/input/input.go +++ b/roomserver/internal/input/input.go @@ -44,8 +44,7 @@ type Inputer struct { type inputTask struct { ctx context.Context event *api.InputRoomEvent - wg *sync.WaitGroup - err error // written back by worker, only safe to read when all tasks are done + err chan error } type inputWorker struct { @@ -70,13 +69,13 @@ func (w *inputWorker) start() { default: } hooks.Run(hooks.KindNewEventReceived, task.event.Event) - _, task.err = w.r.processRoomEvent(task.ctx, task.event) - if task.err == nil { + _, err := w.r.processRoomEvent(task.ctx, task.event) + if err == nil { hooks.Run(hooks.KindNewEventPersisted, task.event.Event) } else { - sentry.CaptureException(task.err) + sentry.CaptureException(err) } - task.wg.Done() + task.err <- err case <-time.After(time.Second * 5): return } @@ -134,9 +133,11 @@ func (r *Inputer) InputRoomEvents( // Create a wait group. Each task that we dispatch will call Done on // this wait group so that we know when all of our events have been // processed. - wg := &sync.WaitGroup{} - wg.Add(len(request.InputRoomEvents)) - tasks := make([]*inputTask, len(request.InputRoomEvents)) + ctx, cancel := context.WithCancel(ctx) + defer cancel() + count := len(request.InputRoomEvents) + wait := make(chan error, count) + tasks := make([]*inputTask, count) for i, e := range request.InputRoomEvents { // Work out if we are running per-room workers or if we're just doing @@ -161,7 +162,7 @@ func (r *Inputer) InputRoomEvents( tasks[i] = &inputTask{ ctx: ctx, event: &request.InputRoomEvents[i], - wg: wg, + err: wait, } // Send the task to the worker. @@ -171,15 +172,23 @@ func (r *Inputer) InputRoomEvents( worker.input.push(tasks[i]) } - // Wait for all of the workers to return results about our tasks. - wg.Wait() + // Wait for the request context to close and then + go func() { + <-ctx.Done() + close(wait) + }() + // Wait for all of the workers to return results about our tasks. // If any of the tasks returned an error, we should probably report // that back to the caller. - for _, task := range tasks { - if task.err != nil { - response.ErrMsg = task.err.Error() - _, rejected := task.err.(*gomatrixserverlib.NotAllowed) + for err := range wait { + count-- + if count == 0 { + cancel() + } + if err != nil { + response.ErrMsg = err.Error() + _, rejected := err.(*gomatrixserverlib.NotAllowed) response.NotAllowed = rejected return }