mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 13:52:46 +00:00
Consistent *sql.Tx
usage across sync API (#2744)
This tidies up the `storage` package so that everything takes a transaction parameter instead of something things that do and some that don't.
This commit is contained in:
parent
a574ed5369
commit
3f9e38e80a
20 changed files with 99 additions and 77 deletions
|
@ -91,14 +91,14 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataInRange(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
userID string,
|
||||
r types.Range,
|
||||
filter *gomatrixserverlib.EventFilter,
|
||||
) (data map[string][]string, pos types.StreamPosition, err error) {
|
||||
data = make(map[string][]string)
|
||||
stmt, params, err := prepareWithFilters(
|
||||
s.db, nil, selectAccountDataInRangeSQL,
|
||||
s.db, txn, selectAccountDataInRangeSQL,
|
||||
[]interface{}{
|
||||
userID, r.Low(), r.High(),
|
||||
},
|
||||
|
|
|
@ -82,9 +82,9 @@ func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
|||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||
ctx context.Context, roomID string,
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (bwExtrems map[string][]string, err error) {
|
||||
rows, err := s.selectBackwardExtremitiesForRoomStmt.QueryContext(ctx, roomID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectBackwardExtremitiesForRoomStmt).QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
|
@ -163,9 +163,9 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (t
|
|||
|
||||
// SelectJoinedUsers returns a map of room ID to a list of joined user IDs.
|
||||
func (s *currentRoomStateStatements) SelectJoinedUsers(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (map[string][]string, error) {
|
||||
rows, err := s.selectJoinedUsersStmt.QueryContext(ctx)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt).QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -187,7 +187,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsers(
|
|||
|
||||
// SelectJoinedUsersInRoom returns a map of room ID to a list of joined user IDs for a given room.
|
||||
func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
|
||||
ctx context.Context, roomIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, roomIDs []string,
|
||||
) (map[string][]string, error) {
|
||||
query := strings.Replace(selectJoinedUsersInRoomSQL, "($1)", sqlutil.QueryVariadic(len(roomIDs)), 1)
|
||||
params := make([]interface{}, 0, len(roomIDs))
|
||||
|
@ -200,7 +200,7 @@ func (s *currentRoomStateStatements) SelectJoinedUsersInRoom(
|
|||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsersInRoom: stmt.close() failed")
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -401,9 +401,9 @@ func rowsToEvents(rows *sql.Rows) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
|||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectStateEvent(
|
||||
ctx context.Context, roomID, evType, stateKey string,
|
||||
ctx context.Context, txn *sql.Tx, roomID, evType, stateKey string,
|
||||
) (*gomatrixserverlib.HeaderedEvent, error) {
|
||||
stmt := s.selectStateEventStmt
|
||||
stmt := sqlutil.TxStmt(txn, s.selectStateEventStmt)
|
||||
var res []byte
|
||||
err := stmt.QueryRowContext(ctx, roomID, evType, stateKey).Scan(&res)
|
||||
if err == sql.ErrNoRows {
|
||||
|
@ -429,10 +429,17 @@ func (s *currentRoomStateStatements) SelectSharedUsers(
|
|||
params[k+1] = v
|
||||
}
|
||||
|
||||
var provider sqlutil.QueryProvider
|
||||
if txn == nil {
|
||||
provider = s.db
|
||||
} else {
|
||||
provider = txn
|
||||
}
|
||||
|
||||
result := make([]string, 0, len(otherUserIDs))
|
||||
query := strings.Replace(selectSharedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(otherUserIDs), 1), 1)
|
||||
err := sqlutil.RunLimitedVariablesQuery(
|
||||
ctx, query, s.db, params, sqlutil.SQLite3MaxVariables,
|
||||
ctx, query, provider, params, sqlutil.SQLite3MaxVariables,
|
||||
func(rows *sql.Rows) error {
|
||||
var stateKey string
|
||||
for rows.Next() {
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
@ -77,11 +78,11 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
|||
}
|
||||
|
||||
func (s *filterStatements) SelectFilter(
|
||||
ctx context.Context, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||
ctx context.Context, txn *sql.Tx, target *gomatrixserverlib.Filter, localpart string, filterID string,
|
||||
) error {
|
||||
// Retrieve filter from database (stored as canonical JSON)
|
||||
var filterData []byte
|
||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||
err := sqlutil.TxStmt(txn, s.selectFilterStmt).QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -94,7 +95,7 @@ func (s *filterStatements) SelectFilter(
|
|||
}
|
||||
|
||||
func (s *filterStatements) InsertFilter(
|
||||
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, filter *gomatrixserverlib.Filter, localpart string,
|
||||
) (filterID string, err error) {
|
||||
var existingFilterID string
|
||||
|
||||
|
@ -115,8 +116,9 @@ func (s *filterStatements) InsertFilter(
|
|||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
err = sqlutil.TxStmt(txn, s.selectFilterIDByContentStmt).QueryRowContext(
|
||||
ctx, localpart, filterJSON,
|
||||
).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
|
@ -126,7 +128,7 @@ func (s *filterStatements) InsertFilter(
|
|||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||
res, err := sqlutil.TxStmt(txn, s.insertFilterStmt).ExecContext(ctx, filterJSON, localpart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
@ -91,7 +91,12 @@ func (r *notificationDataStatements) SelectUserUnreadCountsForRooms(
|
|||
params[i+1] = roomIDs[i]
|
||||
}
|
||||
sql := strings.Replace(selectUserUnreadNotificationsForRooms, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
|
||||
rows, err := r.db.QueryContext(ctx, sql, params...)
|
||||
prep, err := r.db.PrepareContext(ctx, sql)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, prep, "SelectUserUnreadCountsForRooms: prep.close() failed")
|
||||
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -164,12 +164,12 @@ func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Even
|
|||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
|
||||
func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent) error {
|
||||
headeredJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||
_, err = sqlutil.TxStmt(txn, s.updateEventJSONStmt).ExecContext(ctx, headeredJSON, event.EventID())
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -647,7 +647,7 @@ func (s *outputRoomEventsStatements) ReIndex(ctx context.Context, txn *sql.Tx, l
|
|||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, stmt, "selectEvents: stmt.close() failed")
|
||||
rows, err := stmt.QueryContext(ctx, params...)
|
||||
rows, err := sqlutil.TxStmt(txn, stmt).QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -176,9 +176,9 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
|
|||
ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, backwardOrdering bool,
|
||||
) (topoPos types.StreamPosition, err error) {
|
||||
if backwardOrdering {
|
||||
err = s.selectStreamToTopologicalPositionDescStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
|
||||
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionDescStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
|
||||
} else {
|
||||
err = s.selectStreamToTopologicalPositionAscStmt.QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
|
||||
err = sqlutil.TxStmt(txn, s.selectStreamToTopologicalPositionAscStmt).QueryRowContext(ctx, roomID, streamPos).Scan(&topoPos)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
@ -172,9 +172,9 @@ func (s *peekStatements) SelectPeeksInRange(
|
|||
}
|
||||
|
||||
func (s *peekStatements) SelectPeekingDevices(
|
||||
ctx context.Context,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) (peekingDevices map[string][]types.PeekingDevice, err error) {
|
||||
rows, err := s.selectPeekingDevicesStmt.QueryContext(ctx)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectPeekingDevicesStmt).QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -108,7 +108,7 @@ func (r *receiptStatements) UpsertReceipt(ctx context.Context, txn *sql.Tx, room
|
|||
}
|
||||
|
||||
// SelectRoomReceiptsAfter select all receipts for a given room after a specific timestamp
|
||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
|
||||
func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, txn *sql.Tx, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []types.OutputReceiptEvent, error) {
|
||||
selectSQL := strings.Replace(selectRoomReceipts, "($2)", sqlutil.QueryVariadicOffset(len(roomIDs), 1), 1)
|
||||
var lastPos types.StreamPosition
|
||||
params := make([]interface{}, len(roomIDs)+1)
|
||||
|
@ -116,7 +116,12 @@ func (r *receiptStatements) SelectRoomReceiptsAfter(ctx context.Context, roomIDs
|
|||
for k, v := range roomIDs {
|
||||
params[k+1] = v
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx, selectSQL, params...)
|
||||
prep, err := r.db.Prepare(selectSQL)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("unable to prepare statement: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, prep, "SelectRoomReceiptsAfter: prep.close() failed")
|
||||
rows, err := sqlutil.TxStmt(txn, prep).QueryContext(ctx, params...)
|
||||
if err != nil {
|
||||
return 0, nil, fmt.Errorf("unable to query room receipts: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue