diff --git a/roomserver/internal/helpers/auth.go b/roomserver/internal/helpers/auth.go index 1f4215e7..e69f4c93 100644 --- a/roomserver/internal/helpers/auth.go +++ b/roomserver/internal/helpers/auth.go @@ -56,7 +56,7 @@ func CheckForSoftFail( // Then get the state entries for the current state snapshot. // We'll use this to check if the event is allowed right now. - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, *roomInfo, nil) authStateEntries, err = roomState.LoadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) if err != nil { return true, fmt.Errorf("roomState.LoadStateAtSnapshot: %w", err) diff --git a/roomserver/internal/helpers/helpers.go b/roomserver/internal/helpers/helpers.go index a829bffc..680628fb 100644 --- a/roomserver/internal/helpers/helpers.go +++ b/roomserver/internal/helpers/helpers.go @@ -172,7 +172,7 @@ func GetMembershipsAtState( } func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) { - roomState := state.NewStateResolution(db, info) + roomState := state.NewStateResolution(db, info, nil) // Lookup the event NID eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID}) if err != nil { @@ -217,7 +217,7 @@ func LoadStateEvents( func CheckServerAllowedToSeeEvent( ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool, ) (bool, error) { - roomState := state.NewStateResolution(db, info) + roomState := state.NewStateResolution(db, info, nil) stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -379,7 +379,7 @@ func QueryLatestEventsAndState( return nil } - roomState := state.NewStateResolution(db, *roomInfo) + roomState := state.NewStateResolution(db, *roomInfo, nil) response.RoomExists = true response.RoomVersion = roomInfo.RoomVersion diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2a558c48..46727c21 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -249,7 +249,7 @@ func (r *Inputer) calculateAndSetState( isRejected bool, ) error { var err error - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(r.DB, roomInfo, nil) if input.HasState && !isRejected { // Check here if we think we're in the room already. @@ -271,7 +271,7 @@ func (r *Inputer) calculateAndSetState( } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, nil, roomInfo.RoomNID, nil, entries); err != nil { return fmt.Errorf("r.DB.AddState: %w", err) } } else { diff --git a/roomserver/internal/input/input_latest_events.go b/roomserver/internal/input/input_latest_events.go index c9264a27..f9b79670 100644 --- a/roomserver/internal/input/input_latest_events.go +++ b/roomserver/internal/input/input_latest_events.go @@ -199,7 +199,7 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error { func (u *latestEventsUpdater) latestState() error { var err error - roomState := state.NewStateResolution(u.api.DB, *u.roomInfo) + roomState := state.NewStateResolution(u.api.DB, *u.roomInfo, u.updater.Txn) // Work out if the state at the extremities has actually changed // or not. If they haven't then we won't bother doing all of the diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d9d720f2..6c33f3c8 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -146,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform } var beforeStateSnapshotNID types.StateSnapshotNID - if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if beforeStateSnapshotNID, err = r.DB.AddState(ctx, nil, roomNID, nil, entries); err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") return err } diff --git a/roomserver/internal/perform/perform_inbound_peek.go b/roomserver/internal/perform/perform_inbound_peek.go index eb3c9727..7fbc101f 100644 --- a/roomserver/internal/perform/perform_inbound_peek.go +++ b/roomserver/internal/perform/perform_inbound_peek.go @@ -79,7 +79,7 @@ func (r *InboundPeeker) PerformInboundPeek( response.LatestEvent = sortedLatestEvents[0].Headered(info.RoomVersion) // XXX: do we actually need to do a state resolution here? - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, *info, nil) var stateEntries []types.StateEntry stateEntries, err = roomState.LoadStateAtSnapshot( diff --git a/roomserver/internal/perform/perform_invite.go b/roomserver/internal/perform/perform_invite.go index fa65ce9b..7293a1e7 100644 --- a/roomserver/internal/perform/perform_invite.go +++ b/roomserver/internal/perform/perform_invite.go @@ -231,7 +231,7 @@ func buildInviteStrippedState( StateKey: "", }) } - roomState := state.NewStateResolution(db, *info) + roomState := state.NewStateResolution(db, *info, nil) stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples( ctx, info.StateSnapshotNID, stateWanted, ) diff --git a/roomserver/internal/query/query.go b/roomserver/internal/query/query.go index f69f67f7..60581133 100644 --- a/roomserver/internal/query/query.go +++ b/roomserver/internal/query/query.go @@ -62,7 +62,7 @@ func (r *Queryer) QueryStateAfterEvents( return nil } - roomState := state.NewStateResolution(r.DB, *info) + roomState := state.NewStateResolution(r.DB, *info, nil) response.RoomExists = true response.RoomVersion = info.RoomVersion @@ -507,7 +507,7 @@ func (r *Queryer) QueryStateAndAuthChain( } func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo types.RoomInfo, eventIDs []string) ([]*gomatrixserverlib.Event, error) { - roomState := state.NewStateResolution(r.DB, roomInfo) + roomState := state.NewStateResolution(r.DB, roomInfo, nil) prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs) if err != nil { switch err.(type) { diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 0d9511ac..a1243e16 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -18,6 +18,7 @@ package state import ( "context" + "database/sql" "fmt" "sort" "time" @@ -34,13 +35,15 @@ type StateResolution struct { db storage.Database roomInfo types.RoomInfo events map[types.EventNID]*gomatrixserverlib.Event + txn *sql.Tx // optional } -func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResolution { +func NewStateResolution(db storage.Database, roomInfo types.RoomInfo, tx *sql.Tx) StateResolution { return StateResolution{ db: db, roomInfo: roomInfo, events: make(map[types.EventNID]*gomatrixserverlib.Event), + txn: tx, } } @@ -549,7 +552,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) + stateNID, err := v.db.AddState(ctx, v.txn, v.roomInfo.RoomNID, nil, nil) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) } @@ -581,7 +584,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( // add the state event as a block of size one to the end of the blocks. metrics.algorithm = "single_delta" stateNID, err := v.db.AddState( - ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ctx, v.txn, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, ) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) @@ -626,7 +629,7 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( // previous state. metrics.conflictLength = conflictLength metrics.fullStateLength = len(state) - return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) + return metrics.stop(v.db.AddState(ctx, v.txn, roomNID, nil, state)) } func (v *StateResolution) calculateStateAfterManyEvents( diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index d2b0e75c..b3bc1188 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -16,6 +16,7 @@ package storage import ( "context" + "database/sql" "github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/storage/shared" @@ -32,6 +33,7 @@ type Database interface { // Store the room state at an event in the database AddState( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, diff --git a/roomserver/storage/shared/latest_events_updater.go b/roomserver/storage/shared/latest_events_updater.go index 36865081..c2a9b313 100644 --- a/roomserver/storage/shared/latest_events_updater.go +++ b/roomserver/storage/shared/latest_events_updater.go @@ -16,6 +16,7 @@ type LatestEventsUpdater struct { latestEvents []types.StateAtEventAndReference lastEventIDSent string currentStateSnapshotNID types.StateSnapshotNID + Txn *sql.Tx } func rollback(txn *sql.Tx) { @@ -46,7 +47,7 @@ func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomI } } return &LatestEventsUpdater{ - transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, + transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID, txn, }, nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b48772..1689b882 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -137,11 +137,12 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo func (d *Database) AddState( ctx context.Context, + txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { - err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { + err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error { if len(state) > 0 { var stateBlockNID types.StateBlockNID stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)