Full roomserver input transactional isolation (#2141)

* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

* Better transaction management

* Tweak order

* Handle cases where the room does not exist

* Other fixes

* More tweaks

* Fill some gaps

* Fill in the gaps

* good lord it gets worse

* Don't roll back transactions when events rejected

* Pass through errors properly

* Fix bugs

* Fix incorrect error check

* Don't panic on nil txns

* Tweaks

* Hopefully fix panics for good in SQLite this time

* Fix rollback

* Minor bug fixes with latest event updater

* Some review comments

* Revert "Some review comments"

This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e.

* Fix a couple of bugs

* Clearer commit and rollback results

* Remove unnecessary prepares
This commit is contained in:
Neil Alexander 2022-02-04 10:39:34 +00:00 committed by GitHub
parent 4d9f5b2e57
commit eb352a5f6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 867 additions and 499 deletions

View file

@ -1,133 +0,0 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type LatestEventsUpdater struct {
transaction
d *Database
roomInfo types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &LatestEventsUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View file

@ -0,0 +1,262 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type RoomUpdater struct {
transaction
d *Database
roomInfo *types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
// If the roomInfo is nil then that means that the room doesn't exist
// yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
// would involve locking a row on the table that doesn't exist. Instead
// we will just run with a normal database transaction. It'll either
// succeed, processing a create event which creates the room, or it won't.
if roomInfo == nil {
return &RoomUpdater{
transaction{ctx, txn}, d, nil, nil, "", 0,
}, nil
}
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Commit() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Commit()
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Rollback() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Rollback()
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
})
}
func (u *RoomUpdater) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
}
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
}
func (u *RoomUpdater) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
}
func (u *RoomUpdater) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
}
func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}
func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}
func (u *RoomUpdater) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
}
func (u *RoomUpdater) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
}
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}
func (u *RoomUpdater) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View file

@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true
type Database struct {
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes)
}
func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}
@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}
func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
}
func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}
@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
}
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
}
func (d *Database) stateEntriesForTuples(
ctx context.Context, txn *sql.Tx,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := []types.StateEntryList{}
for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
@ -153,13 +177,22 @@ func (d *Database) AddState(
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}
func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
}
@ -180,7 +213,7 @@ func (d *Database) AddState(
}
}
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 {
// If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
@ -205,7 +238,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
return d.eventNIDs(ctx, nil, eventIDs)
}
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
}
func (d *Database) SetState(
@ -219,24 +258,34 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
}
func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID)
return d.snapshotNIDFromEventID(ctx, nil, eventID)
}
func (d *Database) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
return stateNID, err
}
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
return d.eventsFromIDs(ctx, nil, eventIDs)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil {
return nil, err
}
@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid)
}
return d.Events(ctx, nids)
return d.events(ctx, txn, nids)
}
func (d *Database) LatestEventIDs(
@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
return d.stateBlockNIDs(ctx, nil, stateNIDs)
}
func (d *Database) stateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
}
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.stateEntries(ctx, nil, stateBlockNIDs)
}
func (d *Database) stateEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
}
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
}
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
}
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
}
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID,
ctx, nil, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room
@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}
func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly,
ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
)
}
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
}
func (d *Database) GetInvitesForUser(
@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
return d.events(ctx, nil, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
@ -398,7 +471,7 @@ func (d *Database) Events(
}
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}
@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
return updater, err
}
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo *types.RoomInfo,
) (*RoomUpdater, error) {
if d.GetRoomUpdaterFn != nil {
return d.GetRoomUpdaterFn(ctx, roomInfo)
}
txn, err := d.DB.Begin()
if err != nil {
return nil, err
}
var updater *LatestEventsUpdater
var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err
})
return updater, err
@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
redactedEventID string
err error
)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var txn *sql.Tx
if updater != nil {
txn = updater.txn
}
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well,
// on Postgres at least).
var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
}
// Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents
// and EndTransaction in a writer then it's possible for a new write txn to be made between the two
// function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
// something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
// SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
// as they don't go via InputRoomEvents
err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return fmt.Errorf("updater.StorePreviousEvents: %w", err)
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
succeeded := true
err = sqlutil.EndTransaction(updater, &succeeded)
return err
})
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", err
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
}
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}
return eventNID, roomNID, types.StateAtEvent{
@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true)
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
}
func (d *Database) assignRoomNID(
@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
}
@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
}
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
}
@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}
@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
}
@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
}
@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil {
return nil, err
}
@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID)
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
}
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName)
return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
}
// GetKnownUsers searches all users that userID knows about.
@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil {
return nil, err
}
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx)
return d.RoomsTable.SelectRoomIDs(ctx, nil)
}
// ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil {
return err
}