mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 21:32:46 +00:00
Full roomserver input transactional isolation (#2141)
* Add transaction to all database tables in roomserver, rename latest events updater to room updater, use room updater for all RS input * Better transaction management * Tweak order * Handle cases where the room does not exist * Other fixes * More tweaks * Fill some gaps * Fill in the gaps * good lord it gets worse * Don't roll back transactions when events rejected * Pass through errors properly * Fix bugs * Fix incorrect error check * Don't panic on nil txns * Tweaks * Hopefully fix panics for good in SQLite this time * Fix rollback * Minor bug fixes with latest event updater * Some review comments * Revert "Some review comments" This reverts commit 0caf8cf53e62c33f7b83c52e9df1d963871f751e. * Fix a couple of bugs * Clearer commit and rollback results * Remove unnecessary prepares
This commit is contained in:
parent
4d9f5b2e57
commit
eb352a5f6b
35 changed files with 867 additions and 499 deletions
|
@ -76,15 +76,20 @@ func (s *eventJSONStatements) InsertEventJSON(
|
|||
}
|
||||
|
||||
func (s *eventJSONStatements) BulkSelectEventJSON(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) ([]tables.EventJSONPair, error) {
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventJSONSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, selectOrig, iEventNIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -112,15 +112,20 @@ func (s *eventStateKeyStatements) SelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
||||
ctx context.Context, eventStateKeys []string,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeys []string,
|
||||
) (map[string]types.EventStateKeyNID, error) {
|
||||
iEventStateKeys := make([]interface{}, len(eventStateKeys))
|
||||
for k, v := range eventStateKeys {
|
||||
iEventStateKeys[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeySQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeys)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, selectOrig, iEventStateKeys...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -138,15 +143,19 @@ func (s *eventStateKeyStatements) BulkSelectEventStateKeyNID(
|
|||
}
|
||||
|
||||
func (s *eventStateKeyStatements) BulkSelectEventStateKey(
|
||||
ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx, eventStateKeyNIDs []types.EventStateKeyNID,
|
||||
) (map[types.EventStateKeyNID]string, error) {
|
||||
iEventStateKeyNIDs := make([]interface{}, len(eventStateKeyNIDs))
|
||||
for k, v := range eventStateKeyNIDs {
|
||||
iEventStateKeyNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectEventStateKeyNIDSQL, "($1)", sqlutil.QueryVariadic(len(eventStateKeyNIDs)), 1)
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, selectOrig, iEventStateKeyNIDs...)
|
||||
selectPrep, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
rows, err := stmt.QueryContext(ctx, iEventStateKeyNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -128,7 +128,7 @@ func (s *eventTypeStatements) SelectEventTypeNID(
|
|||
}
|
||||
|
||||
func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
||||
ctx context.Context, eventTypes []string,
|
||||
ctx context.Context, txn *sql.Tx, eventTypes []string,
|
||||
) (map[string]types.EventTypeNID, error) {
|
||||
///////////////
|
||||
iEventTypes := make([]interface{}, len(eventTypes))
|
||||
|
@ -140,9 +140,10 @@ func (s *eventTypeStatements) BulkSelectEventTypeNID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, selectPrep)
|
||||
///////////////
|
||||
|
||||
rows, err := selectPrep.QueryContext(ctx, iEventTypes...)
|
||||
rows, err := stmt.QueryContext(ctx, iEventTypes...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -184,7 +184,7 @@ func (s *eventStatements) SelectEvent(
|
|||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateEntry, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
|
@ -196,6 +196,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
|
@ -235,7 +236,7 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntry, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
|
@ -263,6 +264,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("s.db.Prepare: %w", err)
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("selectStmt.QueryContext: %w", err)
|
||||
|
@ -291,7 +293,7 @@ func (s *eventStatements) BulkSelectStateEventByNID(
|
|||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
func (s *eventStatements) BulkSelectStateAtEventByID(
|
||||
ctx context.Context, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) ([]types.StateAtEvent, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
|
@ -303,6 +305,7 @@ func (s *eventStatements) BulkSelectStateAtEventByID(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
@ -381,6 +384,7 @@ func (s *eventStatements) BulkSelectStateAtEventAndReference(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectPrep = sqlutil.TxStmt(txn, selectPrep)
|
||||
//////////////
|
||||
|
||||
rows, err := sqlutil.TxStmt(txn, selectPrep).QueryContext(ctx, iEventNIDs...)
|
||||
|
@ -454,7 +458,7 @@ func (s *eventStatements) BulkSelectEventReference(
|
|||
}
|
||||
|
||||
// bulkSelectEventID returns a map from numeric event ID to string event ID.
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
|
@ -465,6 +469,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
|
@ -490,7 +495,7 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, eventNIDs []typ
|
|||
|
||||
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
|
||||
// If an event ID is not in the database then it is omitted from the map.
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
|
||||
///////////////
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for k, v := range eventIDs {
|
||||
|
@ -501,6 +506,7 @@ func (s *eventStatements) BulkSelectEventNID(ctx context.Context, eventIDs []str
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
///////////////
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventIDs...)
|
||||
if err != nil {
|
||||
|
@ -538,13 +544,14 @@ func (s *eventStatements) SelectMaxEventDepth(ctx context.Context, txn *sql.Tx,
|
|||
}
|
||||
|
||||
func (s *eventStatements) SelectRoomNIDsForEventNIDs(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
|
||||
) (map[types.EventNID]types.RoomNID, error) {
|
||||
sqlStr := strings.Replace(selectRoomNIDsForEventNIDsSQL, "($1)", sqlutil.QueryVariadic(len(eventNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for i, v := range eventNIDs {
|
||||
iEventNIDs[i] = v
|
||||
|
|
|
@ -88,8 +88,8 @@ func prepareInvitesTable(db *sql.DB) (tables.Invites, error) {
|
|||
}
|
||||
|
||||
func (s *inviteStatements) InsertInviteEvent(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, inviteEventID string, roomNID types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
inviteEventID string, roomNID types.RoomNID,
|
||||
targetUserNID, senderUserNID types.EventStateKeyNID,
|
||||
inviteEventJSON []byte,
|
||||
) (bool, error) {
|
||||
|
@ -109,8 +109,8 @@ func (s *inviteStatements) InsertInviteEvent(
|
|||
}
|
||||
|
||||
func (s *inviteStatements) UpdateInviteRetired(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventIDs []string, err error) {
|
||||
// gather all the event IDs we will retire
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInvitesAboutToRetireStmt)
|
||||
|
@ -134,10 +134,11 @@ func (s *inviteStatements) UpdateInviteRetired(
|
|||
|
||||
// selectInviteActiveForUserInRoom returns a list of sender state key NIDs
|
||||
func (s *inviteStatements) SelectInviteActiveForUserInRoom(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
targetUserNID types.EventStateKeyNID, roomNID types.RoomNID,
|
||||
) ([]types.EventStateKeyNID, []string, error) {
|
||||
rows, err := s.selectInviteActiveForUserInRoomStmt.QueryContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectInviteActiveForUserInRoomStmt)
|
||||
rows, err := stmt.QueryContext(
|
||||
ctx, targetUserNID, roomNID,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
@ -184,17 +184,18 @@ func (s *membershipStatements) SelectMembershipForUpdate(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipFromRoomAndTarget(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
) (eventNID types.EventNID, membership tables.MembershipState, forgotten bool, err error) {
|
||||
err = s.selectMembershipFromRoomAndTargetStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipFromRoomAndTargetStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, roomNID, targetUserNID,
|
||||
).Scan(&membership, &eventNID, &forgotten)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoom(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var selectStmt *sql.Stmt
|
||||
|
@ -203,6 +204,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
} else {
|
||||
selectStmt = s.selectMembershipsFromRoomStmt
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, roomNID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -220,7 +222,7 @@ func (s *membershipStatements) SelectMembershipsFromRoom(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, membership tables.MembershipState, localOnly bool,
|
||||
) (eventNIDs []types.EventNID, err error) {
|
||||
var stmt *sql.Stmt
|
||||
|
@ -229,6 +231,7 @@ func (s *membershipStatements) SelectMembershipsFromRoomAndMembership(
|
|||
} else {
|
||||
stmt = s.selectMembershipsFromRoomAndMembershipStmt
|
||||
}
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomNID, membership)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -258,9 +261,10 @@ func (s *membershipStatements) UpdateMembership(
|
|||
}
|
||||
|
||||
func (s *membershipStatements) SelectRoomsWithMembership(
|
||||
ctx context.Context, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, membershipState tables.MembershipState,
|
||||
) ([]types.RoomNID, error) {
|
||||
rows, err := s.selectRoomsWithMembershipStmt.QueryContext(ctx, membershipState, userID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomsWithMembershipStmt)
|
||||
rows, err := stmt.QueryContext(ctx, membershipState, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -276,13 +280,19 @@ func (s *membershipStatements) SelectRoomsWithMembership(
|
|||
return roomNIDs, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) (map[types.EventStateKeyNID]int, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
query := strings.Replace(selectJoinedUsersSetForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, query, iRoomNIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, query, iRoomNIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -299,8 +309,9 @@ func (s *membershipStatements) SelectJoinedUsersSetForRooms(ctx context.Context,
|
|||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
rows, err := s.selectKnownUsersStmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
func (s *membershipStatements) SelectKnownUsers(ctx context.Context, txn *sql.Tx, userID types.EventStateKeyNID, searchString string, limit int) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectKnownUsersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, userID, fmt.Sprintf("%%%s%%", searchString), limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -317,8 +328,8 @@ func (s *membershipStatements) SelectKnownUsers(ctx context.Context, userID type
|
|||
}
|
||||
|
||||
func (s *membershipStatements) UpdateForgetMembership(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||
forget bool,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.updateMembershipForgetRoomStmt).ExecContext(
|
||||
|
@ -327,9 +338,10 @@ func (s *membershipStatements) UpdateForgetMembership(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||
func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectLocalServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
@ -340,9 +352,10 @@ func (s *membershipStatements) SelectLocalServerInRoom(ctx context.Context, room
|
|||
return found, nil
|
||||
}
|
||||
|
||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
var nid types.RoomNID
|
||||
err := s.selectServerInRoomStmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectServerInRoomStmt)
|
||||
err := stmt.QueryRowContext(ctx, tables.MembershipStateJoin, roomNID, serverName).Scan(&nid)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
|
|
|
@ -75,9 +75,10 @@ func (s *publishedStatements) UpsertRoomPublished(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectPublishedFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (published bool, err error) {
|
||||
err = s.selectPublishedStmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectPublishedStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID).Scan(&published)
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
|
@ -85,9 +86,10 @@ func (s *publishedStatements) SelectPublishedFromRoomID(
|
|||
}
|
||||
|
||||
func (s *publishedStatements) SelectAllPublishedRooms(
|
||||
ctx context.Context, published bool,
|
||||
ctx context.Context, txn *sql.Tx, published bool,
|
||||
) ([]string, error) {
|
||||
rows, err := s.selectAllPublishedStmt.QueryContext(ctx, published)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAllPublishedStmt)
|
||||
rows, err := stmt.QueryContext(ctx, published)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -91,9 +91,10 @@ func (s *roomAliasesStatements) InsertRoomAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (roomID string, err error) {
|
||||
err = s.selectRoomIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&roomID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
@ -101,10 +102,11 @@ func (s *roomAliasesStatements) SelectRoomIDFromAlias(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (aliases []string, err error) {
|
||||
aliases = []string{}
|
||||
rows, err := s.selectAliasesFromRoomIDStmt.QueryContext(ctx, roomID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAliasesFromRoomIDStmt)
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -124,9 +126,10 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
|
|||
}
|
||||
|
||||
func (s *roomAliasesStatements) SelectCreatorIDFromAlias(
|
||||
ctx context.Context, alias string,
|
||||
ctx context.Context, txn *sql.Tx, alias string,
|
||||
) (creatorID string, err error) {
|
||||
err = s.selectCreatorIDFromAliasStmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
stmt := sqlutil.TxStmt(txn, s.selectCreatorIDFromAliasStmt)
|
||||
err = stmt.QueryRowContext(ctx, alias).Scan(&creatorID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
|
|
|
@ -107,8 +107,9 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
||||
rows, err := s.selectRoomIDsStmt.QueryContext(ctx)
|
||||
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -124,10 +125,11 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context) ([]string, error) {
|
|||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||
func (s *roomStatements) SelectRoomInfo(ctx context.Context, txn *sql.Tx, roomID string) (*types.RoomInfo, error) {
|
||||
var info types.RoomInfo
|
||||
var latestNIDsJSON string
|
||||
err := s.selectRoomInfoStmt.QueryRowContext(ctx, roomID).Scan(
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRoomInfoStmt)
|
||||
err := stmt.QueryRowContext(ctx, roomID).Scan(
|
||||
&info.RoomVersion, &info.RoomNID, &info.StateSnapshotNID, &latestNIDsJSON,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -224,13 +226,14 @@ func (s *roomStatements) UpdateLatestEventNIDs(
|
|||
}
|
||||
|
||||
func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
||||
ctx context.Context, roomNIDs []types.RoomNID,
|
||||
ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID,
|
||||
) (map[types.RoomNID]gomatrixserverlib.RoomVersion, error) {
|
||||
sqlStr := strings.Replace(selectRoomVersionsForRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
sqlPrep, err := s.db.Prepare(sqlStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sqlPrep = sqlutil.TxStmt(txn, sqlPrep)
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
|
@ -252,13 +255,19 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
|
|||
return result, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roomNIDs []types.RoomNID) ([]string, error) {
|
||||
iRoomNIDs := make([]interface{}, len(roomNIDs))
|
||||
for i, v := range roomNIDs {
|
||||
iRoomNIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomNIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomNIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -274,13 +283,19 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, roomNIDs []types
|
|||
return roomIDs, nil
|
||||
}
|
||||
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, roomIDs []string) ([]types.RoomNID, error) {
|
||||
func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, roomIDs []string) ([]types.RoomNID, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i, v := range roomIDs {
|
||||
iRoomIDs[i] = v
|
||||
}
|
||||
sqlQuery := strings.Replace(bulkSelectRoomNIDsSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, sqlQuery, iRoomIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -81,8 +81,7 @@ func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
|||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
entries types.StateEntries,
|
||||
) (id types.StateBlockNID, err error) {
|
||||
entries = entries[:util.SortAndUnique(entries)]
|
||||
|
@ -94,14 +93,15 @@ func (s *stateBlockStatements) BulkInsertStateData(
|
|||
if err != nil {
|
||||
return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
err = s.insertStateDataStmt.QueryRowContext(
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
|
||||
err = stmt.QueryRowContext(
|
||||
ctx, nids.Hash(), js,
|
||||
).Scan(&id)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs types.StateBlockNIDs,
|
||||
ctx context.Context, txn *sql.Tx, stateBlockNIDs types.StateBlockNIDs,
|
||||
) ([][]types.EventNID, error) {
|
||||
intfs := make([]interface{}, len(stateBlockNIDs))
|
||||
for i := range stateBlockNIDs {
|
||||
|
@ -112,6 +112,7 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := selectStmt.QueryContext(ctx, intfs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -106,7 +106,7 @@ func (s *stateSnapshotStatements) InsertState(
|
|||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
ctx context.Context, txn *sql.Tx, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for k, v := range stateNIDs {
|
||||
|
@ -117,6 +117,7 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
selectStmt = sqlutil.TxStmt(txn, selectStmt)
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
|
|
|
@ -172,23 +172,23 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
|||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
RedactionsTable: redactions,
|
||||
GetLatestEventsForUpdateFn: d.GetLatestEventsForUpdate,
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
MembershipTable: membership,
|
||||
PublishedTable: published,
|
||||
RedactionsTable: redactions,
|
||||
GetRoomUpdaterFn: d.GetRoomUpdater,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -201,16 +201,16 @@ func (d *Database) SupportsConcurrentRoomInputs() bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (d *Database) GetLatestEventsForUpdate(
|
||||
ctx context.Context, roomInfo types.RoomInfo,
|
||||
) (*shared.LatestEventsUpdater, error) {
|
||||
func (d *Database) GetRoomUpdater(
|
||||
ctx context.Context, roomInfo *types.RoomInfo,
|
||||
) (*shared.RoomUpdater, error) {
|
||||
// TODO: Do not use transactions. We should be holding open this transaction but we cannot have
|
||||
// multiple write transactions on sqlite. The code will perform additional
|
||||
// write transactions independent of this one which will consistently cause
|
||||
// 'database is locked' errors. As sqlite doesn't support multi-process on the
|
||||
// same DB anyway, and we only execute updates sequentially, the only worries
|
||||
// are for rolling back when things go wrong. (atomicity)
|
||||
return shared.NewLatestEventsUpdater(ctx, &d.Database, nil, roomInfo)
|
||||
return shared.NewRoomUpdater(ctx, &d.Database, nil, roomInfo)
|
||||
}
|
||||
|
||||
func (d *Database) MembershipUpdater(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue