Full roomserver input transactional isolation (#2141)

* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

* Better transaction management

* Tweak order

* Handle cases where the room does not exist

* Other fixes

* More tweaks

* Fill some gaps

* Fill in the gaps

* good lord it gets worse

* Don't roll back transactions when events rejected

* Pass through errors properly

* Fix bugs

* Fix incorrect error check

* Don't panic on nil txns

* Tweaks

* Hopefully fix panics for good in SQLite this time

* Fix rollback

* Minor bug fixes with latest event updater

* Some review comments

* Revert "Some review comments"

This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e.

* Fix a couple of bugs

* Clearer commit and rollback results

* Remove unnecessary prepares
This commit is contained in:
Neil Alexander 2022-02-04 10:39:34 +00:00 committed by GitHub
parent 4d9f5b2e57
commit eb352a5f6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 867 additions and 499 deletions

View file

@ -86,11 +86,10 @@ type Database interface {
// Lookup the event IDs for a batch of event numeric IDs.
// Returns an error if the retrieval went wrong.
EventIDs(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// Look up the latest events in a room in preparation for an update.
// The RoomRecentEventsUpdater must have Commit or Rollback called on it if this doesn't return an error.
// Returns the latest events in the room and the last eventID sent to the log along with an updater.
// Opens and returns a room updater, which locks the room and opens a transaction.
// The GetRoomUpdater must have Commit or Rollback called on it if this doesn't return an error.
// If this returns an error then no further action is required.
GetLatestEventsForUpdate(ctx context.Context, roomInfo types.RoomInfo) (*shared.LatestEventsUpdater, error)
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
// Look up event references for the latest events in the room and the current state snapshot.
// Returns the latest events, the current state and the maximum depth of the latest events plus 1.
// Returns an error if there was a problem talking to the database.

View file

@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
}
func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}

View file

@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
rows, err := stmt.QueryContext(
ctx, pq.StringArray(eventStateKeys),
)
if err != nil {
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
for i := range eventStateKeyNIDs {
nIDs[i] = int64(eventStateKeyNIDs[i])
}
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
rows, err := stmt.QueryContext(ctx, nIDs)
if err != nil {
return nil, err
}

View file

@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
if err != nil {
return nil, err
}

View file

@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByNID lookups a list of state events by event NID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
if err != nil {
return nil, err
}
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
if err != nil {
return nil, err
}
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}

View file

@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
}
func (s *inviteStatements) InsertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
ctx context.Context, txn *sql.Tx,
inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
}
func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {

View file

@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
}
func (s *membershipStatements) InsertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
}
func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID,
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}
func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
if localOnly {
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
stmt = s.selectMembershipsFromRoomStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
}
func (s *membershipStatements) UpdateMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil {
return nil, err
}
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
func (s *membershipStatements) SelectKnownUsers(
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}
func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget,
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil

View file

@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
}
func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
}
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
}
func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool,
ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil {
return nil, err
}

View file

@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
}
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) ([]string, error) {
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
}
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}

View file

@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
return types.RoomNID(roomNID), err
}
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDs pq.Int64Array
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
)
if err == sql.ErrNoRows {
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
) ([]types.EventNID, types.StateSnapshotNID, error) {
var nids pq.Int64Array
var stateSnapshotNID int64
stmt := s.selectLatestEventNIDsStmt
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
if err != nil {
return nil, 0, err
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
}
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
if err != nil {
return nil, err
}
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil
}
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
var array pq.Int64Array
for _, nid := range roomNIDs {
array = append(array, int64(nid))
}
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil
}
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
var array pq.StringArray
for _, roomID := range roomIDs {
array = append(array, roomID)
}
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
rows, err := stmt.QueryContext(ctx, array)
if err != nil {
return nil, err
}

View file

@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
}
func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
txn *sql.Tx,
ctx context.Context, txn *sql.Tx,
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
for _, e := range entries {
nids = append(nids, e.EventNID)
}
err = s.insertStateDataStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), eventNIDsAsArray(nids),
).Scan(&id)
return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) {
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
if err != nil {
return nil, err
}

View file

@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}

View file

@ -1,133 +0,0 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type LatestEventsUpdater struct {
transaction
d *Database
roomInfo types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewLatestEventsUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo types.RoomInfo) (*LatestEventsUpdater, error) {
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &LatestEventsUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *LatestEventsUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, u.txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *LatestEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *LatestEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View file

@ -0,0 +1,262 @@
package shared
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
type RoomUpdater struct {
transaction
d *Database
roomInfo *types.RoomInfo
latestEvents []types.StateAtEventAndReference
lastEventIDSent string
currentStateSnapshotNID types.StateSnapshotNID
}
func rollback(txn *sql.Tx) {
if txn == nil {
return
}
txn.Rollback() // nolint: errcheck
}
func NewRoomUpdater(ctx context.Context, d *Database, txn *sql.Tx, roomInfo *types.RoomInfo) (*RoomUpdater, error) {
// If the roomInfo is nil then that means that the room doesn't exist
// yet, so we can't do `SelectLatestEventsNIDsForUpdate` because that
// would involve locking a row on the table that doesn't exist. Instead
// we will just run with a normal database transaction. It'll either
// succeed, processing a create event which creates the room, or it won't.
if roomInfo == nil {
return &RoomUpdater{
transaction{ctx, txn}, d, nil, nil, "", 0,
}, nil
}
eventNIDs, lastEventNIDSent, currentStateSnapshotNID, err :=
d.RoomsTable.SelectLatestEventsNIDsForUpdate(ctx, txn, roomInfo.RoomNID)
if err != nil {
rollback(txn)
return nil, err
}
stateAndRefs, err := d.EventsTable.BulkSelectStateAtEventAndReference(ctx, txn, eventNIDs)
if err != nil {
rollback(txn)
return nil, err
}
var lastEventIDSent string
if lastEventNIDSent != 0 {
lastEventIDSent, err = d.EventsTable.SelectEventID(ctx, txn, lastEventNIDSent)
if err != nil {
rollback(txn)
return nil, err
}
}
return &RoomUpdater{
transaction{ctx, txn}, d, roomInfo, stateAndRefs, lastEventIDSent, currentStateSnapshotNID,
}, nil
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Commit() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Commit()
}
// Implements sqlutil.Transaction
func (u *RoomUpdater) Rollback() error {
if u.txn == nil { // SQLite mode probably
return nil
}
return u.txn.Rollback()
}
// RoomVersion implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) RoomVersion() (version gomatrixserverlib.RoomVersion) {
return u.roomInfo.RoomVersion
}
// LatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LatestEvents() []types.StateAtEventAndReference {
return u.latestEvents
}
// LastEventIDSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) LastEventIDSent() string {
return u.lastEventIDSent
}
// CurrentStateSnapshotNID implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) CurrentStateSnapshotNID() types.StateSnapshotNID {
return u.currentStateSnapshotNID
}
// StorePreviousEvents implements types.RoomRecentEventsUpdater - This must be called from a Writer
func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEventReferences []gomatrixserverlib.EventReference) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
for _, ref := range previousEventReferences {
if err := u.d.PrevEventsTable.InsertPreviousEvent(u.ctx, txn, ref.EventID, ref.EventSHA256, eventNID); err != nil {
return fmt.Errorf("u.d.PrevEventsTable.InsertPreviousEvent: %w", err)
}
}
return nil
})
}
func (u *RoomUpdater) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
return u.d.snapshotNIDFromEventID(ctx, u.txn, eventID)
}
func (u *RoomUpdater) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return u.d.storeEvent(ctx, u, event, authEventNIDs, isRejected)
}
func (u *RoomUpdater) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return u.d.stateBlockNIDs(ctx, u.txn, stateNIDs)
}
func (u *RoomUpdater) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return u.d.stateEntries(ctx, u.txn, stateBlockNIDs)
}
func (u *RoomUpdater) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return u.d.stateEntriesForTuples(ctx, u.txn, stateBlockNIDs, stateKeyTuples)
}
func (u *RoomUpdater) AddState(
ctx context.Context,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return u.d.addState(ctx, u.txn, roomNID, stateBlockNIDs, state)
}
func (u *RoomUpdater) SetState(
ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID,
) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventState(ctx, txn, eventNID, stateNID)
})
}
func (u *RoomUpdater) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return u.d.eventTypeNIDs(ctx, u.txn, eventTypes)
}
func (u *RoomUpdater) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return u.d.eventStateKeyNIDs(ctx, u.txn, eventStateKeys)
}
func (u *RoomUpdater) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return u.d.roomInfo(ctx, u.txn, roomID)
}
func (u *RoomUpdater) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return u.d.EventsTable.BulkSelectEventID(ctx, u.txn, eventNIDs)
}
func (u *RoomUpdater) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return u.d.EventsTable.BulkSelectStateEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return u.d.getMembershipEventNIDsForRoom(ctx, u.txn, roomNID, joinOnly, localOnly)
}
// IsReferenced implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) IsReferenced(eventReference gomatrixserverlib.EventReference) (bool, error) {
err := u.d.PrevEventsTable.SelectPreviousEventExists(u.ctx, u.txn, eventReference.EventID, eventReference.EventSHA256)
if err == nil {
return true, nil
}
if err == sql.ErrNoRows {
return false, nil
}
return false, fmt.Errorf("u.d.PrevEventsTable.SelectPreviousEventExists: %w", err)
}
// SetLatestEvents implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) SetLatestEvents(
roomNID types.RoomNID, latest []types.StateAtEventAndReference, lastEventNIDSent types.EventNID,
currentStateSnapshotNID types.StateSnapshotNID,
) error {
eventNIDs := make([]types.EventNID, len(latest))
for i := range latest {
eventNIDs[i] = latest[i].EventNID
}
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
if err := u.d.RoomsTable.UpdateLatestEventNIDs(u.ctx, txn, roomNID, eventNIDs, lastEventNIDSent, currentStateSnapshotNID); err != nil {
return fmt.Errorf("u.d.RoomsTable.updateLatestEventNIDs: %w", err)
}
if roomID, ok := u.d.Cache.GetRoomServerRoomID(roomNID); ok {
if roomInfo, ok := u.d.Cache.GetRoomInfo(roomID); ok {
roomInfo.StateSnapshotNID = currentStateSnapshotNID
roomInfo.IsStub = false
u.d.Cache.StoreRoomInfo(roomID, roomInfo)
}
}
return nil
})
}
// HasEventBeenSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) HasEventBeenSent(eventNID types.EventNID) (bool, error) {
return u.d.EventsTable.SelectEventSentToOutput(u.ctx, u.txn, eventNID)
}
// MarkEventAsSent implements types.RoomRecentEventsUpdater
func (u *RoomUpdater) MarkEventAsSent(eventNID types.EventNID) error {
return u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
return u.d.EventsTable.UpdateEventSentToOutput(u.ctx, txn, eventNID)
})
}
func (u *RoomUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (*MembershipUpdater, error) {
return u.d.membershipUpdaterTxn(u.ctx, u.txn, u.roomInfo.RoomNID, targetUserNID, targetLocal)
}

View file

@ -26,23 +26,23 @@ import (
const redactionsArePermanent = true
type Database struct {
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetLatestEventsForUpdateFn func(ctx context.Context, roomInfo types.RoomInfo) (*LatestEventsUpdater, error)
DB *sql.DB
Cache caching.RoomServerCaches
Writer sqlutil.Writer
EventsTable tables.Events
EventJSONTable tables.EventJSON
EventTypesTable tables.EventTypes
EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms
StateSnapshotTable tables.StateSnapshot
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites
MembershipTable tables.Membership
PublishedTable tables.Published
RedactionsTable tables.Redactions
GetRoomUpdaterFn func(ctx context.Context, roomInfo *types.RoomInfo) (*RoomUpdater, error)
}
func (d *Database) SupportsConcurrentRoomInputs() bool {
@ -51,6 +51,12 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
func (d *Database) EventTypeNIDs(
ctx context.Context, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
return d.eventTypeNIDs(ctx, nil, eventTypes)
}
func (d *Database) eventTypeNIDs(
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
result := make(map[string]types.EventTypeNID)
remaining := []string{}
@ -62,7 +68,7 @@ func (d *Database) EventTypeNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, remaining)
nids, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@ -77,11 +83,17 @@ func (d *Database) EventTypeNIDs(
func (d *Database) EventStateKeys(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, eventStateKeyNIDs)
return d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, eventStateKeyNIDs)
}
func (d *Database) EventStateKeyNIDs(
ctx context.Context, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
return d.eventStateKeyNIDs(ctx, nil, eventStateKeys)
}
func (d *Database) eventStateKeyNIDs(
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
remaining := []string{}
@ -93,7 +105,7 @@ func (d *Database) EventStateKeyNIDs(
}
}
if len(remaining) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, remaining)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, remaining)
if err != nil {
return nil, err
}
@ -108,23 +120,31 @@ func (d *Database) EventStateKeyNIDs(
func (d *Database) StateEntriesForEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateEntry, error) {
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateEventByID(ctx, nil, eventIDs)
}
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.stateEntriesForTuples(ctx, nil, stateBlockNIDs, stateKeyTuples)
}
func (d *Database) stateEntriesForTuples(
ctx context.Context, txn *sql.Tx,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := []types.StateEntryList{}
for i, entry := range entries {
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, stateKeyTuples)
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, stateKeyTuples)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@ -137,10 +157,14 @@ func (d *Database) StateEntriesForTuples(
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
return d.roomInfo(ctx, nil, roomID)
}
func (d *Database) roomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil
}
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, roomID)
roomInfo, err := d.RoomsTable.SelectRoomInfo(ctx, txn, roomID)
if err == nil && roomInfo != nil {
d.Cache.StoreRoomServerRoomID(roomInfo.RoomNID, roomID)
d.Cache.StoreRoomInfo(roomID, *roomInfo)
@ -153,13 +177,22 @@ func (d *Database) AddState(
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
return d.addState(ctx, nil, roomNID, stateBlockNIDs, state)
}
func (d *Database) addState(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) {
if len(stateBlockNIDs) > 0 && len(state) > 0 {
// Check to see if the event already appears in any of the existing state
// blocks. If it does then we should not add it again, as this will just
// result in excess state blocks and snapshots.
// TODO: Investigate why this is happening - probably input_events.go!
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
blocks, berr := d.StateBlockTable.BulkSelectStateBlockEntries(ctx, txn, stateBlockNIDs)
if berr != nil {
return 0, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", berr)
}
@ -180,7 +213,7 @@ func (d *Database) AddState(
}
}
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
if len(state) > 0 {
// If there's any state left to add then let's add new blocks.
var stateBlockNID types.StateBlockNID
@ -205,7 +238,13 @@ func (d *Database) AddState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, eventIDs)
return d.eventNIDs(ctx, nil, eventIDs)
}
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) (map[string]types.EventNID, error) {
return d.EventsTable.BulkSelectEventNID(ctx, txn, eventIDs)
}
func (d *Database) SetState(
@ -219,24 +258,34 @@ func (d *Database) SetState(
func (d *Database) StateAtEventIDs(
ctx context.Context, eventIDs []string,
) ([]types.StateAtEvent, error) {
return d.EventsTable.BulkSelectStateAtEventByID(ctx, eventIDs)
return d.EventsTable.BulkSelectStateAtEventByID(ctx, nil, eventIDs)
}
func (d *Database) SnapshotNIDFromEventID(
ctx context.Context, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, nil, eventID)
return d.snapshotNIDFromEventID(ctx, nil, eventID)
}
func (d *Database) snapshotNIDFromEventID(
ctx context.Context, txn *sql.Tx, eventID string,
) (types.StateSnapshotNID, error) {
_, stateNID, err := d.EventsTable.SelectEvent(ctx, txn, eventID)
return stateNID, err
}
func (d *Database) EventIDs(
ctx context.Context, eventNIDs []types.EventNID,
) (map[types.EventNID]string, error) {
return d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.EventNIDs(ctx, eventIDs)
return d.eventsFromIDs(ctx, nil, eventIDs)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs)
if err != nil {
return nil, err
}
@ -246,7 +295,7 @@ func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]type
nids = append(nids, nid)
}
return d.Events(ctx, nids)
return d.events(ctx, txn, nids)
}
func (d *Database) LatestEventIDs(
@ -271,21 +320,33 @@ func (d *Database) LatestEventIDs(
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
return d.stateBlockNIDs(ctx, nil, stateNIDs)
}
func (d *Database) stateBlockNIDs(
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, txn, stateNIDs)
}
func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
return d.stateEntries(ctx, nil, stateBlockNIDs)
}
func (d *Database) stateEntries(
ctx context.Context, txn *sql.Tx, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
entries, err := d.StateBlockTable.BulkSelectStateBlockEntries(
ctx, stateBlockNIDs,
ctx, txn, stateBlockNIDs,
)
if err != nil {
return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err)
}
lists := make([]types.StateEntryList, 0, len(entries))
for i, entry := range entries {
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry, nil)
eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, txn, entry, nil)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
@ -304,17 +365,17 @@ func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string
}
func (d *Database) GetRoomIDForAlias(ctx context.Context, alias string) (string, error) {
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectRoomIDFromAlias(ctx, nil, alias)
}
func (d *Database) GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error) {
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, roomID)
return d.RoomAliasesTable.SelectAliasesFromRoomID(ctx, nil, roomID)
}
func (d *Database) GetCreatorIDForAlias(
ctx context.Context, alias string,
) (string, error) {
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, alias)
return d.RoomAliasesTable.SelectCreatorIDFromAlias(ctx, nil, alias)
}
func (d *Database) RemoveRoomAlias(ctx context.Context, alias string) error {
@ -335,7 +396,7 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
senderMembershipEventNID, senderMembership, isRoomforgotten, err :=
d.MembershipTable.SelectMembershipFromRoomAndTarget(
ctx, roomNID, requestSenderUserNID,
ctx, nil, roomNID, requestSenderUserNID,
)
if err == sql.ErrNoRows {
// The user has never been a member of that room
@ -349,14 +410,20 @@ func (d *Database) GetMembership(ctx context.Context, roomNID types.RoomNID, req
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
return d.getMembershipEventNIDsForRoom(ctx, nil, roomNID, joinOnly, localOnly)
}
func (d *Database) getMembershipEventNIDsForRoom(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) ([]types.EventNID, error) {
if joinOnly {
return d.MembershipTable.SelectMembershipsFromRoomAndMembership(
ctx, roomNID, tables.MembershipStateJoin, localOnly,
ctx, txn, roomNID, tables.MembershipStateJoin, localOnly,
)
}
return d.MembershipTable.SelectMembershipsFromRoom(ctx, roomNID, localOnly)
return d.MembershipTable.SelectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
}
func (d *Database) GetInvitesForUser(
@ -364,22 +431,28 @@ func (d *Database) GetInvitesForUser(
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
) (senderUserIDs []types.EventStateKeyNID, eventIDs []string, err error) {
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, targetUserNID, roomNID)
return d.InvitesTable.SelectInviteActiveForUserInRoom(ctx, nil, targetUserNID, roomNID)
}
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
return d.events(ctx, nil, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]types.Event, error) {
eventJSONs, err := d.EventJSONTable.BulkSelectEventJSON(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, eventNIDs)
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
@ -398,7 +471,7 @@ func (d *Database) Events(
}
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, fetchNIDList)
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}
@ -440,19 +513,19 @@ func (d *Database) MembershipUpdater(
return updater, err
}
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomInfo types.RoomInfo,
) (*LatestEventsUpdater, error) {
if d.GetLatestEventsForUpdateFn != nil {
return d.GetLatestEventsForUpdateFn(ctx, roomInfo)
func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo *types.RoomInfo,
) (*RoomUpdater, error) {
if d.GetRoomUpdaterFn != nil {
return d.GetRoomUpdaterFn(ctx, roomInfo)
}
txn, err := d.DB.Begin()
if err != nil {
return nil, err
}
var updater *LatestEventsUpdater
var updater *RoomUpdater
_ = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
updater, err = NewLatestEventsUpdater(ctx, d, txn, roomInfo)
updater, err = NewRoomUpdater(ctx, d, txn, roomInfo)
return err
})
return updater, err
@ -461,6 +534,13 @@ func (d *Database) GetLatestEventsForUpdate(
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
}
func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
@ -472,8 +552,11 @@ func (d *Database) StoreEvent(
redactedEventID string
err error
)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var txn *sql.Tx
if updater != nil {
txn = updater.txn
}
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
@ -546,42 +629,32 @@ func (d *Database) StoreEvent(
// events updater because it somewhat works as a mutex, ensuring
// that there's a row-level lock on the latest room events (well,
// on Postgres at least).
var roomInfo *types.RoomInfo
var updater *LatestEventsUpdater
if prevEvents := event.PrevEvents(); len(prevEvents) > 0 {
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
// Create an updater - NB: on sqlite this WILL create a txn as we are directly calling the shared DB form of
// GetLatestEventsForUpdate - not via the SQLiteDatabase form which has `nil` txns. This
// function only does SELECTs though so the created txn (at this point) is just a read txn like
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
updater, err = d.GetLatestEventsForUpdate(ctx, *roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("NewLatestEventsUpdater: %w", err)
}
// Ensure that we atomically store prev events AND commit them. If we don't wrap StorePreviousEvents
// and EndTransaction in a writer then it's possible for a new write txn to be made between the two
// function calls which will then fail with 'database is locked'. This new write txn would HAVE to be
// something like SetRoomAlias/RemoveRoomAlias as normal input events are already done sequentially due to
// SupportsConcurrentRoomInputs() == false on sqlite, though this does not apply to setting room aliases
// as they don't go via InputRoomEvents
err = d.Writer.Do(d.DB, updater.txn, func(txn *sql.Tx) error {
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return fmt.Errorf("updater.StorePreviousEvents: %w", err)
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
roomInfo, err = d.RoomInfo(ctx, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
succeeded := true
err = sqlutil.EndTransaction(updater, &succeeded)
return err
})
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", err
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
}
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}
return eventNID, roomNID, types.StateAtEvent{
@ -603,7 +676,7 @@ func (d *Database) PublishRoom(ctx context.Context, roomID string, publish bool)
}
func (d *Database) GetPublishedRooms(ctx context.Context) ([]string, error) {
return d.PublishedTable.SelectAllPublishedRooms(ctx, true)
return d.PublishedTable.SelectAllPublishedRooms(ctx, nil, true)
}
func (d *Database) assignRoomNID(
@ -875,14 +948,14 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
eventNIDs = append(eventNIDs, e.EventNID)
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
// return the event requested
for _, e := range entries {
if e.EventTypeNID == eventTypeNID && e.EventStateKeyNID == stateKeyNID {
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, []types.EventNID{e.EventNID})
data, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, []types.EventNID{e.EventNID})
if err != nil {
return nil, err
}
@ -922,11 +995,11 @@ func (d *Database) GetRoomsByMembership(ctx context.Context, userID, membership
}
return nil, fmt.Errorf("GetRoomsByMembership: cannot map user ID to state key NID: %w", err)
}
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, stateKeyNID, membershipState)
roomNIDs, err := d.MembershipTable.SelectRoomsWithMembership(ctx, nil, stateKeyNID, membershipState)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to SelectRoomsWithMembership: %w", err)
}
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, roomNIDs)
roomIDs, err := d.RoomsTable.BulkSelectRoomIDs(ctx, nil, roomNIDs)
if err != nil {
return nil, fmt.Errorf("GetRoomsByMembership: failed to lookup room nids: %w", err)
}
@ -945,7 +1018,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
// we don't bother failing the request if we get asked for event types we don't know about, as all that would result in is no matches which
// isn't a failure.
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, eventTypes)
eventTypeNIDMap, err := d.EventTypesTable.BulkSelectEventTypeNID(ctx, nil, eventTypes)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map event type nids: %w", err)
}
@ -965,7 +1038,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, eventStateKeys)
eventStateKeyNIDMap, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, nil, eventStateKeys)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to map state key nids: %w", err)
}
@ -999,11 +1072,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
}
}
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, eventNIDs)
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, eventNIDs)
events, err := d.EventJSONTable.BulkSelectEventJSON(ctx, nil, eventNIDs)
if err != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load event JSON for event nids: %w", err)
}
@ -1027,11 +1100,11 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
// JoinedUsersSetInRooms returns all joined users in the rooms given, along with the count of how many times they appear.
func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string) (map[string]int, error) {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, roomIDs)
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, roomIDs)
if err != nil {
return nil, err
}
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, roomNIDs)
userNIDToCount, err := d.MembershipTable.SelectJoinedUsersSetForRooms(ctx, nil, roomNIDs)
if err != nil {
return nil, err
}
@ -1041,7 +1114,7 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
stateKeyNIDs[i] = nid
i++
}
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, stateKeyNIDs)
nidToUserID, err := d.EventStateKeysTable.BulkSelectEventStateKey(ctx, nil, stateKeyNIDs)
if err != nil {
return nil, err
}
@ -1057,12 +1130,12 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs []string)
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
return d.MembershipTable.SelectLocalServerInRoom(ctx, roomNID)
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
}
// GetServerInRoom returns true if we think a server is in a given room or false otherwise.
func (d *Database) GetServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
return d.MembershipTable.SelectServerInRoom(ctx, roomNID, serverName)
return d.MembershipTable.SelectServerInRoom(ctx, nil, roomNID, serverName)
}
// GetKnownUsers searches all users that userID knows about.
@ -1071,17 +1144,17 @@ func (d *Database) GetKnownUsers(ctx context.Context, userID, searchString strin
if err != nil {
return nil, err
}
return d.MembershipTable.SelectKnownUsers(ctx, stateKeyNID, searchString, limit)
return d.MembershipTable.SelectKnownUsers(ctx, nil, stateKeyNID, searchString, limit)
}
// GetKnownRooms returns a list of all rooms we know about.
func (d *Database) GetKnownRooms(ctx context.Context) ([]string, error) {
return d.RoomsTable.SelectRoomIDs(ctx)
return d.RoomsTable.SelectRoomIDs(ctx, nil)
}
// ForgetRoom sets a users room to forgotten
func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error {
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, []string{roomID})
roomNIDs, err := d.RoomsTable.BulkSelectRoomNIDs(ctx, nil, []string{roomID})
if err != nil {
return err
}

View file

@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
}
func (s *eventJSONStatements) BulkSelectEventJSON(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) ([]tables.EventJSONPair, error) {
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
}
if err != nil {
return nil, err
}

View file

@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
ctx context.Context, eventStateKeys []string,
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
) (map[string]types.EventStateKeyNID, error) {
iEventStateKeys := make([]interface{}, len(eventStateKeys))
for k, v := range eventStateKeys {
iEventStateKeys[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
} else {
rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
}
if err != nil {
return nil, err
}
@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
}
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
) (map[types.EventStateKeyNID]string, error) {
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
for k, v := range eventStateKeyNIDs {
iEventStateKeyNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
selectPrep, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
if err != nil {
return nil, err
}

View file

@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
}
func (s *eventTypeStatements) BulkSelectEventTypeNID(
ctx context.Context, eventTypes []string,
ctx context.Context, txn *sql.Tx, eventTypes []string,
) (map[string]types.EventTypeNID, error) {
///////////////
iEventTypes := make([]interface{}, len(eventTypes))
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
if err != nil {
return nil, err
}
stmt := sqlutil.TxStmt(txn, selectPrep)
///////////////
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
rows, err := stmt.QueryContext(ctx, iEventTypes...)
if err != nil {
return nil, err
}

View file

@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateEntry, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
if err != nil {
return nil, fmt.Errorf("s.db.Prepare: %w", err)
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, params...)
if err != nil {
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
func (s *eventStatements) BulkSelectStateAtEventByID(
ctx context.Context, eventIDs []string,
ctx context.Context, txn *sql.Tx, eventIDs []string,
) ([]types.StateAtEvent, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
if err != nil {
return nil, err
}
selectPrep = sqlutil.TxStmt(txn, selectPrep)
//////////////
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
}
// bulkSelectEventID returns a map from numeric event ID to string event ID.
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
if err != nil {
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
}
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
) (map[types.EventNID]types.RoomNID, error) {
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iEventNIDs := make([]interface{}, len(eventNIDs))
for i, v := range eventNIDs {
iEventNIDs[i] = v

View file

@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
}
func (s *inviteStatements) InsertInviteEvent(
ctx context.Context,
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
ctx context.Context, txn *sql.Tx,
inviteEventID string, roomNID types.RoomNID,
targetUserNID, senderUserNID types.EventStateKeyNID,
inviteEventJSON []byte,
) (bool, error) {
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
}
func (s *inviteStatements) UpdateInviteRetired(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventIDs []string, err error) {
// gather all the event IDs we will retire
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
) ([]types.EventStateKeyNID, []string, error) {
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
rows, err := stmt.QueryContext(
ctx, targetUserNID, roomNID,
)
if err != nil {

View file

@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}
func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var selectStmt *sql.Stmt
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
selectStmt = s.selectMembershipsFromRoomStmt
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, roomNID)
if err != nil {
return nil, err
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}
@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
}
if err != nil {
return nil, err
}
@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}
func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil

View file

@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
}
func (s *publishedStatements) SelectPublishedFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (published bool, err error) {
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
if err == sql.ErrNoRows {
return false, nil
}
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
}
func (s *publishedStatements) SelectAllPublishedRooms(
ctx context.Context, published bool,
ctx context.Context, txn *sql.Tx, published bool,
) ([]string, error) {
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
rows, err := stmt.QueryContext(ctx, published)
if err != nil {
return nil, err
}

View file

@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
}
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (roomID string, err error) {
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
if err == sql.ErrNoRows {
return "", nil
}
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
}
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
ctx context.Context, roomID string,
ctx context.Context, txn *sql.Tx, roomID string,
) (aliases []string, err error) {
aliases = []string{}
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return
}
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
}
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
ctx context.Context, alias string,
ctx context.Context, txn *sql.Tx, alias string,
) (creatorID string, err error) {
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
if err == sql.ErrNoRows {
return "", nil
}

View file

@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
return roomIDs, nil
}
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
var info types.RoomInfo
var latestNIDsJSON string
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
err := stmt.QueryRowContext(ctx, roomID).Scan(
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
)
if err != nil {
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
}
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
ctx context.Context, roomNIDs []types.RoomNID,
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
sqlPrep, err := s.db.Prepare(sqlStr)
if err != nil {
return nil, err
}
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
return result, nil
}
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
iRoomNIDs := make([]interface{}, len(roomNIDs))
for i, v := range roomNIDs {
iRoomNIDs[i] = v
}
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
} else {
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
}
if err != nil {
return nil, err
}
@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
return roomIDs, nil
}
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i, v := range roomIDs {
iRoomIDs[i] = v
}
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
var rows *sql.Rows
var err error
if txn != nil {
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
} else {
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
}
if err != nil {
return nil, err
}

View file

@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
}
func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
txn *sql.Tx,
ctx context.Context, txn *sql.Tx,
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
if err != nil {
return 0, fmt.Errorf("json.Marshal: %w", err)
}
err = s.insertStateDataStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
ctx, nids.Hash(), js,
).Scan(&id)
return
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
) ([][]types.EventNID, error) {
intfs := make([]interface{}, len(stateBlockNIDs))
for i := range stateBlockNIDs {
@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, intfs...)
if err != nil {
return nil, err

View file

@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs {
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
if err != nil {
return nil, err
}
selectStmt = sqlutil.TxStmt(txn, selectStmt)
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {

View file

@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
return err
}
d.Database = shared.Database{
DB: db,
Cache: cache,
Writer: sqlutil.NewExclusiveWriter(),
EventsTable: events,
EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON,
RoomsTable: rooms,
StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases,
InvitesTable: invites,
MembershipTable: membership,
PublishedTable: published,
RedactionsTable: redactions,
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
DB: db,
Cache: cache,
Writer: sqlutil.NewExclusiveWriter(),
EventsTable: events,
EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON,
RoomsTable: rooms,
StateBlockTable: stateBlock,
StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases,
InvitesTable: invites,
MembershipTable: membership,
PublishedTable: published,
RedactionsTable: redactions,
GetRoomUpdaterFn: d.GetRoomUpdater,
}
return nil
}
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
return false
}
func (d *Database) GetLatestEventsForUpdate(
ctx context.Context, roomInfo types.RoomInfo,
) (*shared.LatestEventsUpdater, error) {
func (d *Database) GetRoomUpdater(
ctx context.Context, roomInfo *types.RoomInfo,
) (*shared.RoomUpdater, error) {
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
// multiple write transactions on sqlite. The code will perform additional
// write transactions independent of this one which will consistently cause
// 'database is locked' errors. As sqlite doesn't support multi-process on the
// same DB anyway, and we only execute updates sequentially, the only worries
// are for rolling back when things go wrong. (atomicity)
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
}
func (d *Database) MembershipUpdater(

View file

@ -18,20 +18,20 @@ type EventJSONPair struct {
type EventJSON interface {
// Insert the event JSON. On conflict, replace the event JSON with the new value (for redactions).
InsertEventJSON(ctx context.Context, tx *sql.Tx, eventNID types.EventNID, eventJSON []byte) error
BulkSelectEventJSON(ctx context.Context, eventNIDs []types.EventNID) ([]EventJSONPair, error)
BulkSelectEventJSON(ctx context.Context, tx *sql.Tx, eventNIDs []types.EventNID) ([]EventJSONPair, error)
}
type EventTypes interface {
InsertEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
SelectEventTypeNID(ctx context.Context, tx *sql.Tx, eventType string) (types.EventTypeNID, error)
BulkSelectEventTypeNID(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
BulkSelectEventTypeNID(ctx context.Context, txn *sql.Tx, eventTypes []string) (map[string]types.EventTypeNID, error)
}
type EventStateKeys interface {
InsertEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
SelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKey string) (types.EventStateKeyNID, error)
BulkSelectEventStateKeyNID(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
BulkSelectEventStateKey(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
BulkSelectEventStateKeyNID(ctx context.Context, txn *sql.Tx, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
BulkSelectEventStateKey(ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
}
type Events interface {
@ -42,12 +42,12 @@ type Events interface {
SelectEvent(ctx context.Context, txn *sql.Tx, eventID string) (types.EventNID, types.StateSnapshotNID, error)
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
BulkSelectStateEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
BulkSelectStateAtEventByID(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
BulkSelectStateAtEventByID(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StateAtEvent, error)
UpdateEventState(ctx context.Context, txn *sql.Tx, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
SelectEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) (sentToOutput bool, err error)
UpdateEventSentToOutput(ctx context.Context, txn *sql.Tx, eventNID types.EventNID) error
@ -55,12 +55,12 @@ type Events interface {
BulkSelectStateAtEventAndReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]types.StateAtEventAndReference, error)
BulkSelectEventReference(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) ([]gomatrixserverlib.EventReference, error)
// BulkSelectEventID returns a map from numeric event ID to string event ID.
BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
}
type Rooms interface {
@ -69,29 +69,29 @@ type Rooms interface {
SelectLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.StateSnapshotNID, error)
SelectLatestEventsNIDsForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) ([]types.EventNID, types.EventNID, types.StateSnapshotNID, error)
UpdateLatestEventNIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID, lastEventSentNID types.EventNID, stateSnapshotNID types.StateSnapshotNID) error
SelectRoomVersionsForRoomNIDs(ctx context.Context, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error)
SelectRoomVersionsForRoomNIDs(ctx context.Context, txn *sql.Tx, roomNID []types.RoomNID) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error)
SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error)
SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error)
BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error)
BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error)
}
type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
BulkSelectStateBlockNIDs(ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
}
type StateBlock interface {
BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error)
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
BulkSelectStateBlockEntries(ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error)
//BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
}
type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
SelectRoomIDFromAlias(ctx context.Context, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, alias string) (creatorID string, err error)
SelectRoomIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (roomID string, err error)
SelectAliasesFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) ([]string, error)
SelectCreatorIDFromAlias(ctx context.Context, txn *sql.Tx, alias string) (creatorID string, err error)
DeleteRoomAlias(ctx context.Context, txn *sql.Tx, alias string) (err error)
}
@ -106,7 +106,7 @@ type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEventID string, roomNID types.RoomNID, targetUserNID, senderUserNID types.EventStateKeyNID, inviteEventJSON []byte) (bool, error)
UpdateInviteRetired(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) ([]string, error)
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs and invite event IDs matching those nids.
SelectInviteActiveForUserInRoom(ctx context.Context, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
SelectInviteActiveForUserInRoom(ctx context.Context, txn *sql.Tx, targetUserNID types.EventStateKeyNID, roomNID types.RoomNID) ([]types.EventStateKeyNID, []string, error)
}
type MembershipState int64
@ -121,24 +121,24 @@ const (
type Membership interface {
InsertMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, localTarget bool) error
SelectMembershipForUpdate(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (MembershipState, error)
SelectMembershipFromRoomAndTarget(ctx context.Context, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipFromRoomAndTarget(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) (types.EventNID, MembershipState, bool, error)
SelectMembershipsFromRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, localOnly bool) (eventNIDs []types.EventNID, err error)
SelectMembershipsFromRoomAndMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, membership MembershipState, localOnly bool) (eventNIDs []types.EventNID, err error)
UpdateMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership MembershipState, eventNID types.EventNID, forgotten bool) error
SelectRoomsWithMembership(ctx context.Context, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
SelectRoomsWithMembership(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState MembershipState) ([]types.RoomNID, error)
// SelectJoinedUsersSetForRooms returns the set of all users in the rooms who are joined to any of these rooms, along with the
// counts of how many rooms they are joined.
SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error)
SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error)
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
}
type Published interface {
UpsertRoomPublished(ctx context.Context, txn *sql.Tx, roomID string, published bool) (err error)
SelectPublishedFromRoomID(ctx context.Context, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, published bool) ([]string, error)
SelectPublishedFromRoomID(ctx context.Context, txn *sql.Tx, roomID string) (published bool, err error)
SelectAllPublishedRooms(ctx context.Context, txn *sql.Tx, published bool) ([]string, error)
}
type RedactionInfo struct {