mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Full roomserver input transactional isolation (#2141)
* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input * Better transaction management * Tweak order * Handle cases where the room does not exist * Other fixes * More tweaks * Fill some gaps * Fill in the gaps * good lord it gets worse * Don't roll back transactions when events rejected * Pass through errors properly * Fix bugs * Fix incorrect error check * Don't panic on nil txns * Tweaks * Hopefully fix panics for good in SQLite this time * Fix rollback * Minor bug fixes with latest event updater * Some review comments * Revert "Some review comments" This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e. * Fix a couple of bugs * Clearer commit and rollback results * Remove unnecessary prepares
This commit is contained in:
parent
4d9f5b2e57
commit
eb352a5f6b
35 changed files with 867 additions and 499 deletions
|
@ -81,9 +81,10 @@ func (s *eventJSONStatements) InsertEventJSON(
|
|||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) ([]tables.EventJSONPair, error) {
|
||||
rows, err := s.bulkSelectEventJSONStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventJSONStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -111,9 +111,10 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
rows, err := s.bulkSelectEventStateKeyNIDStmt.QueryContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyNIDStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, pq.StringArray(eventStateKeys),
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -134,13 +135,14 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
nIDs := make(pq.Int64Array, len(eventStateKeyNIDs))
|
||||
for i := range eventStateKeyNIDs {
|
||||
nIDs[i] = int64(eventStateKeyNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectEventStateKeyStmt.QueryContext(ctx, nIDs)
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventStateKeyStmt)
|
||||
rows, err := stmt.QueryContext(ctx, nIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -133,9 +133,10 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
|||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
rows, err := s.bulkSelectEventTypeNIDStmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventTypeNIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventTypes))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -212,9 +212,10 @@ func (s *eventStatements) SelectEvent(
|
|||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
rows, err := s.bulkSelectStateEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -254,13 +255,14 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
// bulkSelectStateEventByNID lookups a list of state events by event NID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
sort.Sort(tuples)
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -291,9 +293,10 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
rows, err := s.bulkSelectStateAtEventByIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateAtEventByIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -428,8 +431,9 @@ func (s *eventStatements) BulkSelectEventReference(
|
|||
}
|
||||
|
||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
rows, err := s.bulkSelectEventIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -455,8 +459,9 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
|||
|
||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
rows, err := s.bulkSelectEventNIDStmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectEventNIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(eventIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -484,9 +489,10 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
|||
}
|
||||
|
||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomNIDsForEventNIDsStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomNIDsForEventNIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -97,8 +97,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
|
|||
}
|
||||
|
||||
func (s *inviteStatements) InsertInviteEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
inviteEventID string, roomNID types.RoomNID,
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
|
@ -116,8 +116,8 @@ func (s *inviteStatements) InsertInviteEvent(
|
|||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateInviteRetiredStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, targetUserNID)
|
||||
|
@ -139,10 +139,11 @@ func (s *inviteStatements) UpdateInviteRetired(
|
|||
|
||||
// SelectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, []string, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
|||
}
|
||||
|
||||
func (s *membershipStatements) InsertMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
localTarget bool,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
|
||||
|
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipForUpdate(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (membership tables.MembershipState, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
|
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID, &forgotten)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||
ctx context.Context, roomNID types.RoomNID, localOnly bool,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var stmt *sql.Stmt
|
||||
if localOnly {
|
||||
|
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
} else {
|
||||
stmt = s.selectMembershipsFromRoomStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var rows *sql.Rows
|
||||
|
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
|||
} else {
|
||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
rows, err = stmt.QueryContext(ctx, roomNID, membership)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) UpdateMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
|
||||
eventNID types.EventNID, forgotten bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
|
||||
|
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNIDs []types.RoomNID,
|
||||
) (map[types.EventStateKeyNID]int, error) {
|
||||
roomIDarray := make([]int64, len(roomNIDs))
|
||||
for i := range roomNIDs {
|
||||
roomIDarray[i] = int64(roomNIDs[i])
|
||||
}
|
||||
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
func (s *membershipStatements) SelectKnownUsers(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID types.EventStateKeyNID, searchString string, limit int,
|
||||
) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
|
|||
}
|
||||
|
||||
func (s *membershipStatements) UpdateForgetMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
forget bool,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||
ctx, roomNID, targetUserNID, forget,
|
||||
|
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
|
|||
return found, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
func (s *membershipStatements) SelectServerInRoom(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
|
@ -73,9 +73,10 @@ func (s *publishedStatements) UpsertRoomPublished(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (published bool, err error) {
|
||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -83,9 +84,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||
ctx context.Context, published bool,
|
||||
ctx context.Context, txn *sql.Tx, published bool,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
|
||||
rows, err := stmt.QueryContext(ctx, published)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -87,9 +87,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
@ -97,9 +98,10 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -118,9 +120,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (creatorID string, err error) {
|
||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
@ -117,8 +117,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -143,10 +144,11 @@ func (s *roomStatements) InsertRoomNID(
|
|||
return types.RoomNID(roomNID), err
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDs pq.Int64Array
|
||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(
|
||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDs,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
|
@ -170,7 +172,7 @@ func (s *roomStatements) SelectLatestEventNIDs(
|
|||
) ([]types.EventNID, types.StateSnapshotNID, error) {
|
||||
var nids pq.Int64Array
|
||||
var stateSnapshotNID int64
|
||||
stmt := s.selectLatestEventNIDsStmt
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLatestEventNIDsStmt)
|
||||
err := stmt.QueryRowContext(ctx, int64(roomNID)).Scan(&nids, &stateSnapshotNID)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
@ -220,9 +222,10 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
|||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||
ctx context.Context, roomNIDs []types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
|
||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||
rows, err := s.selectRoomVersionsForRoomNIDsStmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomVersionsForRoomNIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNIDsAsArray(roomNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -239,12 +242,13 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
var array pq.Int64Array
|
||||
for _, nid := range roomNIDs {
|
||||
array = append(array, int64(nid))
|
||||
}
|
||||
rows, err := s.bulkSelectRoomIDsStmt.QueryContext(ctx, array)
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, array)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -260,12 +264,13 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
|||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
|
||||
var array pq.StringArray
|
||||
for _, roomID := range roomIDs {
|
||||
array = append(array, roomID)
|
||||
}
|
||||
rows, err := s.bulkSelectRoomNIDsStmt.QueryContext(ctx, array)
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectRoomNIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, array)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -86,8 +86,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
|||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
entries types.StateEntries,
|
||||
) (id types.StateBlockNID, err error) {
|
||||
entries = entries[:util.SortAndUnique(entries)]
|
||||
|
@ -95,16 +94,18 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
for _, e := range entries {
|
||||
nids = append(nids, e.EventNID)
|
||||
}
|
||||
err = s.insertStateDataStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, nids.Hash(), eventNIDsAsArray(nids),
|
||||
).Scan(&id)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
||||
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
|
||||
) ([][]types.EventNID, error) {
|
||||
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockEntriesStmt)
|
||||
rows, err := stmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -105,13 +105,14 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]int64, len(stateNIDs))
|
||||
for i := range stateNIDs {
|
||||
nids[i] = int64(stateNIDs[i])
|
||||
}
|
||||
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateBlockNIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(nids))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue