Full roomserver input transactional isolation (#2141)

* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input

* Better transaction management

* Tweak order

* Handle cases where the room does not exist

* Other fixes

* More tweaks

* Fill some gaps

* Fill in the gaps

* good lord it gets worse

* Don't roll back transactions when events rejected

* Pass through errors properly

* Fix bugs

* Fix incorrect error check

* Don't panic on nil txns

* Tweaks

* Hopefully fix panics for good in SQLite this time

* Fix rollback

* Minor bug fixes with latest event updater

* Some review comments

* Revert "Some review comments"

This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e.

* Fix a couple of bugs

* Clearer commit and rollback results

* Remove unnecessary prepares
This commit is contained in:
Neil Alexander 2022-02-04 10:39:34 +00:00 committed by GitHub
parent 4d9f5b2e57
commit eb352a5f6b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 867 additions and 499 deletions

View file

@ -186,8 +186,8 @@ func prepareMembershipTable(db *sql.DB) (tables.Membership, error) {
}
func (s *membershipStatements) InsertMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
localTarget bool,
) error {
stmt := sqlutil.TxStmt(txn, s.insertMembershipStmt)
@ -196,8 +196,8 @@ func (s *membershipStatements) InsertMembership(
}
func (s *membershipStatements) SelectMembershipForUpdate(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (membership tables.MembershipState, err error) {
err = sqlutil.TxStmt(txn, s.selectMembershipForUpdateStmt).QueryRowContext(
ctx, roomNID, targetUserNID,
@ -206,17 +206,19 @@ func (s *membershipStatements) SelectMembershipForUpdate(
}
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
err = stmt.QueryRowContext(
ctx, roomNID, targetUserNID,
).Scan(&membership, &eventNID, &forgotten)
return
}
func (s *membershipStatements) SelectMembershipsFromRoom(
ctx context.Context, roomNID types.RoomNID, localOnly bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var stmt *sql.Stmt
if localOnly {
@ -224,6 +226,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
} else {
stmt = s.selectMembershipsFromRoomStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err := stmt.QueryContext(ctx, roomNID)
if err != nil {
return
@ -241,7 +244,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
}
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
ctx context.Context,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
var rows *sql.Rows
@ -251,6 +254,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
} else {
stmt = s.selectMembershipsFromRoomAndMembershipStmt
}
stmt = sqlutil.TxStmt(txn, stmt)
rows, err = stmt.QueryContext(ctx, roomNID, membership)
if err != nil {
return
@ -268,8 +272,8 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
}
func (s *membershipStatements) UpdateMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, senderUserNID types.EventStateKeyNID, membership tables.MembershipState,
eventNID types.EventNID, forgotten bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipStmt).ExecContext(
@ -279,9 +283,11 @@ func (s *membershipStatements) UpdateMembership(
}
func (s *membershipStatements) SelectRoomsWithMembership(
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, membershipState tables.MembershipState,
) ([]types.RoomNID, error) {
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
rows, err := stmt.QueryContext(ctx, membershipState, userID)
if err != nil {
return nil, err
}
@ -297,12 +303,16 @@ func (s *membershipStatements) SelectRoomsWithMembership(
return roomNIDs, nil
}
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
func (s *membershipStatements) SelectJoinedUsersSetForRooms(
ctx context.Context, txn *sql.Tx,
roomNIDs []types.RoomNID,
) (map[types.EventStateKeyNID]int, error) {
roomIDarray := make([]int64, len(roomNIDs))
for i := range roomNIDs {
roomIDarray[i] = int64(roomNIDs[i])
}
rows, err := s.selectJoinedUsersSetForRoomsStmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersSetForRoomsStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(roomIDarray))
if err != nil {
return nil, err
}
@ -319,8 +329,12 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
return result, rows.Err()
}
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
func (s *membershipStatements) SelectKnownUsers(
ctx context.Context, txn *sql.Tx,
userID types.EventStateKeyNID, searchString string, limit int,
) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
if err != nil {
return nil, err
}
@ -337,9 +351,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
}
func (s *membershipStatements) UpdateForgetMembership(
ctx context.Context,
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
forget bool,
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool,
) error {
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
ctx, roomNID, targetUserNID, forget,
@ -347,9 +360,13 @@ func (s *membershipStatements) UpdateForgetMembership(
return err
}
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
func (s *membershipStatements) SelectLocalServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID,
) (bool, error) {
var nid types.RoomNID
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil
@ -360,9 +377,13 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
return found, nil
}
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
func (s *membershipStatements) SelectServerInRoom(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, serverName gomatrixserverlib.ServerName,
) (bool, error) {
var nid types.RoomNID
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
if err != nil {
if err == sql.ErrNoRows {
return false, nil