From 5106cc807cf22a95420b24f6bfdd5c9ac8aa06de Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Fri, 11 Feb 2022 17:40:14 +0000 Subject: [PATCH] Ensure only one transaction is used for RS input per room (#2178) * Ensure the input API only uses a single transaction * Remove more of the dead query API call * Tidy up * Fix tests hopefully * Don't do unnecessary work for rooms that don't exist * Improve error, fix another case where transaction wasn't used properly * Add a unit test for checking single transaction on RS input API * Fix logic oops when deciding whether to use a transaction in storeEvent --- federationapi/routing/send_test.go | 43 +----- roomserver/api/api.go | 7 - roomserver/api/api_trace.go | 10 -- roomserver/api/query.go | 21 --- roomserver/internal/input/input_events.go | 33 ++--- roomserver/internal/input/input_missing.go | 123 +++++++++--------- roomserver/internal/input/input_test.go | 93 +++++++++++++ roomserver/internal/query/query.go | 33 ----- roomserver/inthttp/client.go | 14 -- roomserver/inthttp/server.go | 14 -- .../storage/postgres/event_json_table.go | 3 +- roomserver/storage/shared/room_updater.go | 29 ++++- roomserver/storage/shared/storage.go | 2 +- 13 files changed, 211 insertions(+), 214 deletions(-) create mode 100644 roomserver/internal/input/input_test.go diff --git a/federationapi/routing/send_test.go b/federationapi/routing/send_test.go index f1f6169d..4280643e 100644 --- a/federationapi/routing/send_test.go +++ b/federationapi/routing/send_test.go @@ -93,11 +93,10 @@ func (o *testEDUProducer) InputCrossSigningKeyUpdate( type testRoomserverAPI struct { api.RoomserverInternalAPITrace - inputRoomEvents []api.InputRoomEvent - queryMissingAuthPrevEvents func(*api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse - queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse - queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse - queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse + inputRoomEvents []api.InputRoomEvent + queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse + queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse + queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse } func (t *testRoomserverAPI) InputRoomEvents( @@ -140,20 +139,6 @@ func (t *testRoomserverAPI) QueryStateAfterEvents( return nil } -// Query the state after a list of events in a room from the room server. -func (t *testRoomserverAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - response.RoomVersion = testRoomVersion - res := t.queryMissingAuthPrevEvents(request) - response.RoomExists = res.RoomExists - response.MissingAuthEventIDs = res.MissingAuthEventIDs - response.MissingPrevEventIDs = res.MissingPrevEventIDs - return nil -} - // Query a list of events by event ID. func (t *testRoomserverAPI) QueryEventsByID( ctx context.Context, @@ -312,15 +297,7 @@ func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomat // The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on // to the roomserver. It's the most basic test possible. func TestBasicTransaction(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } @@ -332,15 +309,7 @@ func TestBasicTransaction(t *testing.T) { // The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver // as it does the auth check. func TestTransactionFailAuthChecks(t *testing.T) { - rsAPI := &testRoomserverAPI{ - queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse { - return api.QueryMissingAuthPrevEventsResponse{ - RoomExists: true, - MissingAuthEventIDs: []string{}, - MissingPrevEventIDs: []string{}, - } - }, - } + rsAPI := &testRoomserverAPI{} pdus := []json.RawMessage{ testData[len(testData)-1], // a message event } diff --git a/roomserver/api/api.go b/roomserver/api/api.go index d35fd84d..e6d37e8f 100644 --- a/roomserver/api/api.go +++ b/roomserver/api/api.go @@ -83,13 +83,6 @@ type RoomserverInternalAPI interface { response *QueryStateAfterEventsResponse, ) error - // Query whether the roomserver is missing any auth or prev events. - QueryMissingAuthPrevEvents( - ctx context.Context, - request *QueryMissingAuthPrevEventsRequest, - response *QueryMissingAuthPrevEventsResponse, - ) error - // Query a list of events by event ID. QueryEventsByID( ctx context.Context, diff --git a/roomserver/api/api_trace.go b/roomserver/api/api_trace.go index 64cbaca4..16f52abb 100644 --- a/roomserver/api/api_trace.go +++ b/roomserver/api/api_trace.go @@ -129,16 +129,6 @@ func (t *RoomserverInternalAPITrace) QueryStateAfterEvents( return err } -func (t *RoomserverInternalAPITrace) QueryMissingAuthPrevEvents( - ctx context.Context, - req *QueryMissingAuthPrevEventsRequest, - res *QueryMissingAuthPrevEventsResponse, -) error { - err := t.Impl.QueryMissingAuthPrevEvents(ctx, req, res) - util.GetLogger(ctx).WithError(err).Infof("QueryMissingAuthPrevEvents req=%+v res=%+v", js(req), js(res)) - return err -} - func (t *RoomserverInternalAPITrace) QueryEventsByID( ctx context.Context, req *QueryEventsByIDRequest, diff --git a/roomserver/api/query.go b/roomserver/api/query.go index 28321715..96d6711c 100644 --- a/roomserver/api/query.go +++ b/roomserver/api/query.go @@ -83,27 +83,6 @@ type QueryStateAfterEventsResponse struct { StateEvents []*gomatrixserverlib.HeaderedEvent `json:"state_events"` } -type QueryMissingAuthPrevEventsRequest struct { - // The room ID to query the state in. - RoomID string `json:"room_id"` - // The list of auth events to check the existence of. - AuthEventIDs []string `json:"auth_event_ids"` - // The list of previous events to check the existence of. - PrevEventIDs []string `json:"prev_event_ids"` -} - -type QueryMissingAuthPrevEventsResponse struct { - // Does the room exist on this roomserver? - // If the room doesn't exist all other fields will be empty. - RoomExists bool `json:"room_exists"` - // The room version of the room. - RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"` - // The event IDs of the auth events that we don't know locally. - MissingAuthEventIDs []string `json:"missing_auth_event_ids"` - // The event IDs of the previous events that we don't know locally. - MissingPrevEventIDs []string `json:"missing_prev_event_ids"` -} - // QueryEventsByIDRequest is a request to QueryEventsByID type QueryEventsByIDRequest struct { // The event IDs to look up. diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 873a051c..4e151699 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -128,20 +128,24 @@ func (r *Inputer) processRoomEvent( } } - missingRes := &api.QueryMissingAuthPrevEventsResponse{} - serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} - if event.Type() != gomatrixserverlib.MRoomCreate || !event.StateKeyEquals("") { - missingReq := &api.QueryMissingAuthPrevEventsRequest{ - RoomID: event.RoomID(), - AuthEventIDs: event.AuthEventIDs(), - PrevEventIDs: event.PrevEventIDs(), - } - if err := r.Queryer.QueryMissingAuthPrevEvents(ctx, missingReq, missingRes); err != nil { - return rollbackTransaction, fmt.Errorf("r.Queryer.QueryMissingAuthPrevEvents: %w", err) - } + // Don't waste time processing the event if the room doesn't exist. + // A room entry locally will only be created in response to a create + // event. + isCreateEvent := event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") + if !updater.RoomExists() && !isCreateEvent { + return rollbackTransaction, fmt.Errorf("room %s does not exist for event %s", event.RoomID(), event.EventID()) + } + + var missingAuth, missingPrev bool + serverRes := &fedapi.QueryJoinedHostServerNamesInRoomResponse{} + if !isCreateEvent { + missingAuthIDs, missingPrevIDs, err := updater.MissingAuthPrevEvents(ctx, event) + if err != nil { + return rollbackTransaction, fmt.Errorf("updater.MissingAuthPrevEvents: %w", err) + } + missingAuth = len(missingAuthIDs) > 0 + missingPrev = !input.HasState && len(missingPrevIDs) > 0 } - missingAuth := len(missingRes.MissingAuthEventIDs) > 0 - missingPrev := !input.HasState && len(missingRes.MissingPrevEventIDs) > 0 if missingAuth || missingPrev { serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{ @@ -246,14 +250,13 @@ func (r *Inputer) processRoomEvent( missingState := missingStateReq{ origin: input.Origin, inputer: r, - queryer: r.Queryer, db: updater, federation: r.FSAPI, keys: r.KeyRing, roomsMu: internal.NewMutexByRoom(), servers: serverRes.ServerNames, hadEvents: map[string]bool{}, - haveEvents: map[string]*gomatrixserverlib.HeaderedEvent{}, + haveEvents: map[string]*gomatrixserverlib.Event{}, } if stateSnapshot, err := missingState.processEventWithMissingState(ctx, event, headered.RoomVersion); err != nil { // Something went wrong with retrieving the missing state, so we can't diff --git a/roomserver/internal/input/input_missing.go b/roomserver/internal/input/input_missing.go index 19771d4b..fc3be798 100644 --- a/roomserver/internal/input/input_missing.go +++ b/roomserver/internal/input/input_missing.go @@ -10,7 +10,7 @@ import ( fedapi "github.com/matrix-org/dendrite/federationapi/api" "github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/roomserver/api" - "github.com/matrix-org/dendrite/roomserver/internal/query" + "github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -27,14 +27,13 @@ type missingStateReq struct { origin gomatrixserverlib.ServerName db *shared.RoomUpdater inputer *Inputer - queryer *query.Queryer keys gomatrixserverlib.JSONVerifier federation fedapi.FederationInternalAPI roomsMu *internal.MutexByRoom servers []gomatrixserverlib.ServerName hadEvents map[string]bool hadEventsMutex sync.Mutex - haveEvents map[string]*gomatrixserverlib.HeaderedEvent + haveEvents map[string]*gomatrixserverlib.Event haveEventsMutex sync.Mutex } @@ -326,20 +325,20 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion for i := range respState.StateEvents { se := respState.StateEvents[i] if se.Type() == h.Type() && se.StateKeyEquals(*h.StateKey()) { - respState.StateEvents[i] = h.Unwrap() + respState.StateEvents[i] = h addedToState = true break } } if !addedToState { - respState.StateEvents = append(respState.StateEvents, h.Unwrap()) + respState.StateEvents = append(respState.StateEvents, h) } } return respState, false, nil } -func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *gomatrixserverlib.HeaderedEvent { +func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event { t.haveEventsMutex.Lock() defer t.haveEventsMutex.Unlock() if cached, exists := t.haveEvents[ev.EventID()]; exists { @@ -350,32 +349,49 @@ func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.HeaderedEvent) *g } func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, roomID, eventID string) *parsedRespState { - var res api.QueryStateAfterEventsResponse - err := t.queryer.QueryStateAfterEvents(ctx, &api.QueryStateAfterEventsRequest{ - RoomID: roomID, - PrevEventIDs: []string{eventID}, - }, &res) - if err != nil || !res.PrevEventsExist { - util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to query state after %s locally, prev exists=%v", eventID, res.PrevEventsExist) + var res parsedRespState + roomInfo, err := t.db.RoomInfo(ctx, roomID) + if err != nil { return nil } - stateEvents := make([]*gomatrixserverlib.HeaderedEvent, len(res.StateEvents)) - for i, ev := range res.StateEvents { + roomState := state.NewStateResolution(t.db, roomInfo) + stateAtEvents, err := t.db.StateAtEventIDs(ctx, []string{eventID}) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to get state after %s locally", eventID) + return nil + } + stateEntries, err := roomState.LoadCombinedStateAfterEvents(ctx, stateAtEvents) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load combined state after %s locally", eventID) + return nil + } + stateEventNIDs := make([]types.EventNID, 0, len(stateEntries)) + for _, entry := range stateEntries { + stateEventNIDs = append(stateEventNIDs, entry.EventNID) + } + stateEvents, err := t.db.Events(ctx, stateEventNIDs) + if err != nil { + util.GetLogger(ctx).WithField("room_id", roomID).WithError(err).Warnf("failed to load state events locally") + return nil + } + res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)) + for _, ev := range stateEvents { // set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this // processEvent request, which is better for memory. - stateEvents[i] = t.cacheAndReturn(ev) + res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event)) t.hadEvent(ev.EventID()) } - // we should never access res.StateEvents again so we delete it here to make GC faster - res.StateEvents = nil - var authEvents []*gomatrixserverlib.Event + // encourage GC + stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign + missingAuthEvents := map[string]bool{} + res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3) for _, ev := range stateEvents { t.haveEventsMutex.Lock() for _, ae := range ev.AuthEventIDs() { if aev, ok := t.haveEvents[ae]; ok { - authEvents = append(authEvents, aev.Unwrap()) + res.AuthEvents = append(res.AuthEvents, aev) } else { missingAuthEvents[ae] = true } @@ -389,25 +405,18 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, room for evID := range missingAuthEvents { missingEventList = append(missingEventList, evID) } - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, - } util.GetLogger(ctx).WithField("count", len(missingEventList)).Debugf("Fetching missing auth events") - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { return nil } - for i, ev := range queryRes.Events { - authEvents = append(authEvents, t.cacheAndReturn(queryRes.Events[i]).Unwrap()) + for i, ev := range events { + res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event)) t.hadEvent(ev.EventID()) } - queryRes.Events = nil } - return &parsedRespState{ - StateEvents: gomatrixserverlib.UnwrapEventHeaders(stateEvents), - AuthEvents: authEvents, - } + return &res } // lookuptStateBeforeEvent returns the room state before the event e, which is just /state_ids and/or /state depending on what @@ -448,7 +457,7 @@ retryAllowedState: return nil, fmt.Errorf("missing auth event %s and failed to look it up: %w", missing.AuthEventID, err2) } util.GetLogger(ctx).Tracef("fetched event %s", missing.AuthEventID) - resolvedStateEvents = append(resolvedStateEvents, h.Unwrap()) + resolvedStateEvents = append(resolvedStateEvents, h) goto retryAllowedState default: } @@ -513,7 +522,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve logger.Debugf("get_missing_events returned %d events", len(missingResp.Events)) missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events)) for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) { - missingEvents = append(missingEvents, t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap()) + missingEvents = append(missingEvents, t.cacheAndReturn(ev)) } // topologically sort and sanity check that we are making forward progress @@ -602,11 +611,11 @@ func (t *missingStateReq) lookupMissingStateViaState( // We load these as trusted as we called state.Check before which loaded them as untrusted. for i, evJSON := range state.AuthEvents { ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.AuthEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + parsedState.AuthEvents[i] = t.cacheAndReturn(ev) } for i, evJSON := range state.StateEvents { ev, _ := gomatrixserverlib.NewEventFromTrustedJSON(evJSON, false, roomVersion) - parsedState.StateEvents[i] = t.cacheAndReturn(ev.Headered(roomVersion)).Unwrap() + parsedState.StateEvents[i] = t.cacheAndReturn(ev) } return parsedState, nil } @@ -634,23 +643,22 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo } t.haveEventsMutex.Unlock() - // fetch as many as we can from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: missingEventList, + events, err := t.db.EventsFromIDs(ctx, missingEventList) + if err != nil { + return nil, fmt.Errorf("t.db.EventsFromIDs: %w", err) } - var queryRes api.QueryEventsByIDResponse - if err = t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { - return nil, err - } - for i, ev := range queryRes.Events { - queryRes.Events[i] = t.cacheAndReturn(queryRes.Events[i]) + + for i, ev := range events { + events[i].Event = t.cacheAndReturn(events[i].Event) t.hadEvent(ev.EventID()) - evID := queryRes.Events[i].EventID() + evID := events[i].EventID() if missing[evID] { delete(missing, evID) } } - queryRes.Events = nil // allow it to be GCed + + // encourage GC + events = nil // nolint:ineffassign concurrentRequests := 8 missingCount := len(missing) @@ -704,7 +712,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo // Define what we'll do in order to fetch the missing event ID. fetch := func(missingEventID string) { - var h *gomatrixserverlib.HeaderedEvent + var h *gomatrixserverlib.Event h, err = t.lookupEvent(ctx, roomVersion, roomID, missingEventID, false) switch err.(type) { case verifySigError: @@ -759,7 +767,7 @@ func (t *missingStateReq) createRespStateFromStateIDs( logrus.Tracef("Missing state event in createRespStateFromStateIDs: %s", stateIDs.StateEventIDs[i]) continue } - respState.StateEvents = append(respState.StateEvents, ev.Unwrap()) + respState.StateEvents = append(respState.StateEvents, ev) } for i := range stateIDs.AuthEventIDs { ev, ok := t.haveEvents[stateIDs.AuthEventIDs[i]] @@ -767,7 +775,7 @@ func (t *missingStateReq) createRespStateFromStateIDs( logrus.Tracef("Missing auth event in createRespStateFromStateIDs: %s", stateIDs.AuthEventIDs[i]) continue } - respState.AuthEvents = append(respState.AuthEvents, ev.Unwrap()) + respState.AuthEvents = append(respState.AuthEvents, ev) } // We purposefully do not do auth checks on the returned events, as they will still // be processed in the exact same way, just as a 'rejected' event @@ -775,17 +783,14 @@ func (t *missingStateReq) createRespStateFromStateIDs( return &respState, nil } -func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.HeaderedEvent, error) { +func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) { if localFirst { // fetch from the roomserver - queryReq := api.QueryEventsByIDRequest{ - EventIDs: []string{missingEventID}, - } - var queryRes api.QueryEventsByIDResponse - if err := t.queryer.QueryEventsByID(ctx, &queryReq, &queryRes); err != nil { + events, err := t.db.EventsFromIDs(ctx, []string{missingEventID}) + if err != nil { util.GetLogger(ctx).Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err) - } else if len(queryRes.Events) == 1 { - return queryRes.Events[0], nil + } else if len(events) == 1 { + return events[0].Event, nil } } var event *gomatrixserverlib.Event @@ -822,7 +827,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs util.GetLogger(ctx).WithError(err).Warnf("Couldn't validate signature of event %q from /event", event.EventID()) return nil, verifySigError{event.EventID(), err} } - return t.cacheAndReturn(event.Headered(roomVersion)), nil + return t.cacheAndReturn(event), nil } func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []*gomatrixserverlib.Event) error { diff --git a/roomserver/internal/input/input_test.go b/roomserver/internal/input/input_test.go new file mode 100644 index 00000000..4fa96628 --- /dev/null +++ b/roomserver/internal/input/input_test.go @@ -0,0 +1,93 @@ +package input_test + +import ( + "context" + "fmt" + "os" + "testing" + "time" + + "github.com/matrix-org/dendrite/internal/caching" + "github.com/matrix-org/dendrite/roomserver/api" + "github.com/matrix-org/dendrite/roomserver/internal/input" + "github.com/matrix-org/dendrite/roomserver/storage" + "github.com/matrix-org/dendrite/setup/config" + "github.com/matrix-org/gomatrixserverlib" +) + +func psqlConnectionString() config.DataSource { + user := os.Getenv("POSTGRES_USER") + if user == "" { + user = "dendrite" + } + dbName := os.Getenv("POSTGRES_DB") + if dbName == "" { + dbName = "dendrite" + } + connStr := fmt.Sprintf( + "user=%s dbname=%s sslmode=disable", user, dbName, + ) + password := os.Getenv("POSTGRES_PASSWORD") + if password != "" { + connStr += fmt.Sprintf(" password=%s", password) + } + host := os.Getenv("POSTGRES_HOST") + if host != "" { + connStr += fmt.Sprintf(" host=%s", host) + } + return config.DataSource(connStr) +} + +func TestSingleTransactionOnInput(t *testing.T) { + deadline, _ := t.Deadline() + if max := time.Now().Add(time.Second * 3); deadline.After(max) { + deadline = max + } + ctx, cancel := context.WithDeadline(context.Background(), deadline) + defer cancel() + + event, err := gomatrixserverlib.NewEventFromTrustedJSON( + []byte(`{"auth_events":[],"content":{"creator":"@neilalexander:dendrite.matrix.org","room_version":"6"},"depth":1,"hashes":{"sha256":"jqOqdNEH5r0NiN3xJtj0u5XUVmRqq9YvGbki1wxxuuM"},"origin":"dendrite.matrix.org","origin_server_ts":1644595362726,"prev_events":[],"prev_state":[],"room_id":"!jSZZRknA6GkTBXNP:dendrite.matrix.org","sender":"@neilalexander:dendrite.matrix.org","signatures":{"dendrite.matrix.org":{"ed25519:6jB2aB":"bsQXO1wketf1OSe9xlndDIWe71W9KIundc6rBw4KEZdGPW7x4Tv4zDWWvbxDsG64sS2IPWfIm+J0OOozbrWIDw"}},"state_key":"","type":"m.room.create"}`), + false, gomatrixserverlib.RoomVersionV6, + ) + if err != nil { + t.Fatal(err) + } + in := api.InputRoomEvent{ + Kind: api.KindOutlier, // don't panic if we generate an output event + Event: event.Headered(gomatrixserverlib.RoomVersionV6), + } + cache, err := caching.NewInMemoryLRUCache(false) + if err != nil { + t.Fatal(err) + } + db, err := storage.Open( + &config.DatabaseOptions{ + ConnectionString: psqlConnectionString(), + MaxOpenConnections: 1, + MaxIdleConnections: 1, + }, + cache, + ) + if err != nil { + t.Logf("PostgreSQL not available (%s), skipping", err) + t.SkipNow() + } + inputter := &input.Inputer{ + DB: db, + } + res := &api.InputRoomEventsResponse{} + inputter.InputRoomEvents( + ctx, + &api.InputRoomEventsRequest{ + InputRoomEvents: []api.InputRoomEvent{in}, + Asynchronous: false, + }, + res, + ) + // If we fail here then it's because we've hit the test deadline, + // so we probably deadlocked + if err := res.Err(); err != nil { + t.Fatal(err) + } +} diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index 05cd686f..c8bbe770 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -125,39 +125,6 @@ func (r *Queryer) QueryStateAfterEvents( return nil } -// QueryMissingAuthPrevEvents implements api.RoomserverInternalAPI -func (r *Queryer) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - info, err := r.DB.RoomInfo(ctx, request.RoomID) - if err != nil { - return err - } - if info == nil { - return errors.New("room doesn't exist") - } - - response.RoomExists = !info.IsStub - response.RoomVersion = info.RoomVersion - - for _, authEventID := range request.AuthEventIDs { - if nids, err := r.DB.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { - response.MissingAuthEventIDs = append(response.MissingAuthEventIDs, authEventID) - } - } - - for _, prevEventID := range request.PrevEventIDs { - state, err := r.DB.StateAtEventIDs(ctx, []string{prevEventID}) - if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { - response.MissingPrevEventIDs = append(response.MissingPrevEventIDs, prevEventID) - } - } - - return nil -} - // QueryEventsByID implements api.RoomserverInternalAPI func (r *Queryer) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/client.go b/roomserver/inthttp/client.go index 4f6a58bd..a61404ef 100644 --- a/roomserver/inthttp/client.go +++ b/roomserver/inthttp/client.go @@ -40,7 +40,6 @@ const ( // Query operations RoomserverQueryLatestEventsAndStatePath = "/roomserver/queryLatestEventsAndState" RoomserverQueryStateAfterEventsPath = "/roomserver/queryStateAfterEvents" - RoomserverQueryMissingAuthPrevEventsPath = "/roomserver/queryMissingAuthPrevEvents" RoomserverQueryEventsByIDPath = "/roomserver/queryEventsByID" RoomserverQueryMembershipForUserPath = "/roomserver/queryMembershipForUser" RoomserverQueryMembershipsForRoomPath = "/roomserver/queryMembershipsForRoom" @@ -302,19 +301,6 @@ func (h *httpRoomserverInternalAPI) QueryStateAfterEvents( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } -// QueryStateAfterEvents implements RoomserverQueryAPI -func (h *httpRoomserverInternalAPI) QueryMissingAuthPrevEvents( - ctx context.Context, - request *api.QueryMissingAuthPrevEventsRequest, - response *api.QueryMissingAuthPrevEventsResponse, -) error { - span, ctx := opentracing.StartSpanFromContext(ctx, "QueryMissingAuthPrevEvents") - defer span.Finish() - - apiURL := h.roomserverURL + RoomserverQueryMissingAuthPrevEventsPath - return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) -} - // QueryEventsByID implements RoomserverQueryAPI func (h *httpRoomserverInternalAPI) QueryEventsByID( ctx context.Context, diff --git a/roomserver/inthttp/server.go b/roomserver/inthttp/server.go index bf319262..691a4583 100644 --- a/roomserver/inthttp/server.go +++ b/roomserver/inthttp/server.go @@ -149,20 +149,6 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router) { return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) - internalAPIMux.Handle( - RoomserverQueryMissingAuthPrevEventsPath, - httputil.MakeInternalAPI("queryMissingAuthPrevEvents", func(req *http.Request) util.JSONResponse { - var request api.QueryMissingAuthPrevEventsRequest - var response api.QueryMissingAuthPrevEventsResponse - if err := json.NewDecoder(req.Body).Decode(&request); err != nil { - return util.ErrorResponse(err) - } - if err := r.QueryMissingAuthPrevEvents(req.Context(), &request, &response); err != nil { - return util.ErrorResponse(err) - } - return util.JSONResponse{Code: http.StatusOK, JSON: &response} - }), - ) internalAPIMux.Handle( RoomserverQueryEventsByIDPath, httputil.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse { diff --git a/roomserver/storage/postgres/event_json_table.go b/roomserver/storage/postgres/event_json_table.go index 433e445d..b3220eff 100644 --- a/roomserver/storage/postgres/event_json_table.go +++ b/roomserver/storage/postgres/event_json_table.go @@ -76,7 +76,8 @@ func prepareEventJSONTable(db *sql.DB) (tables.EventJSON, error) { func (s *eventJSONStatements) InsertEventJSON( ctx context.Context, txn *sql.Tx, eventNID types.EventNID, eventJSON []byte, ) error { - _, err := s.insertEventJSONStmt.ExecContext(ctx, int64(eventNID), eventJSON) + stmt := sqlutil.TxStmt(txn, s.insertEventJSONStmt) + _, err := stmt.ExecContext(ctx, int64(eventNID), eventJSON) return err } diff --git a/roomserver/storage/shared/room_updater.go b/roomserver/storage/shared/room_updater.go index fc75a260..89b878b9 100644 --- a/roomserver/storage/shared/room_updater.go +++ b/roomserver/storage/shared/room_updater.go @@ -16,6 +16,7 @@ type RoomUpdater struct { latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID + roomExists bool } func rollback(txn *sql.Tx) { @@ -33,7 +34,7 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ // succeed, processing a create event which creates the room, or it won't. if roomInfo == nil { return &RoomUpdater{ - transaction{ctx, txn}, d, nil, nil, "", 0, + transaction{ctx, txn}, d, nil, nil, "", 0, false, }, nil } @@ -57,10 +58,15 @@ func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *typ } } return &RoomUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, true, }, nil } +// RoomExists returns true if the room exists and false otherwise. +func (u *RoomUpdater) RoomExists() bool { + return u.roomExists +} + // Implements sqlutil.Transaction func (u *RoomUpdater) Commit() error { if u.txn == nil { // SQLite mode probably @@ -97,6 +103,25 @@ func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID { return u.currentStateSnapshotNID } +func (u *RoomUpdater) MissingAuthPrevEvents( + ctx context.Context, e *gomatrixserverlib.Event, +) (missingAuth, missingPrev []string, err error) { + for _, authEventID := range e.AuthEventIDs() { + if nids, err := u.EventNIDs(ctx, []string{authEventID}); err != nil || len(nids) == 0 { + missingAuth = append(missingAuth, authEventID) + } + } + + for _, prevEventID := range e.PrevEventIDs() { + state, err := u.StateAtEventIDs(ctx, []string{prevEventID}) + if err != nil || len(state) == 0 || (!state[0].IsCreate() && state[0].BeforeStateSnapshotNID == 0) { + missingPrev = append(missingPrev, prevEventID) + } + } + + return +} + // StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error { return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 8319de26..e96c77af 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -553,7 +553,7 @@ func (d *Database) storeEvent( err error ) var txn *sql.Tx - if updater != nil { + if updater != nil && updater.txn != nil { txn = updater.txn } err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {