Component-wide TransactionWriters (#1290)

* Offset updates take place using TransactionWriter

* Refactor TransactionWriter in current state server

* Refactor TransactionWriter in federation sender

* Refactor TransactionWriter in key server

* Refactor TransactionWriter in media API

* Refactor TransactionWriter in server key API

* Refactor TransactionWriter in sync API

* Refactor TransactionWriter in user API

* Fix deadlocking Sync API tests

* Un-deadlock device database

* Fix appservice API

* Rename TransactionWriters to Writers

* Move writers up a layer in sync API

* Document sqlutil.Writer interface

* Add note to Writer documentation
Neil Alexander, 2020-08-21 10:42:08 +01:00, committed by GitHub
parent 5aaf32bbed
commit 9d53351dc2
56 changed files with 483 additions and 483 deletions
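
For orientation before the diff: every call site below reaches the writer through a single method, writer.Do(db, txn, func(txn *sql.Tx) error), so the interface this commit renames from TransactionWriter to Writer and documents can be pictured roughly as follows. This is a sketch inferred from those call sites and from the NewExclusiveWriter constructor used below, not the exact doc comment the commit adds:

import "database/sql"

// Writer is what TransactionWriter becomes after the rename. Do runs f;
// if txn is nil the writer is expected to supply (and commit or roll back)
// its own transaction on db, otherwise f runs inside the caller's txn.
// An exclusive writer additionally serialises calls so that only one
// write runs at a time, which matters for SQLite, where concurrent
// writers block one another.
type Writer interface {
	Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error
}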


@@ -20,7 +20,6 @@ import (
 	"database/sql"
 
 	"github.com/matrix-org/dendrite/internal"
-	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/syncapi/storage/tables"
 	"github.com/matrix-org/dendrite/syncapi/types"
 	"github.com/matrix-org/gomatrixserverlib"
@@ -51,7 +50,6 @@ const selectMaxAccountDataIDSQL = "" +
 type accountDataStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	streamIDStatements *streamIDStatements
 	insertAccountDataStmt *sql.Stmt
 	selectMaxAccountDataIDStmt *sql.Stmt
@@ -61,7 +59,6 @@ type accountDataStatements struct {
 func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
 	s := &accountDataStatements{
 		db: db,
-		writer: sqlutil.NewTransactionWriter(),
 		streamIDStatements: streamID,
 	}
 	_, err := db.Exec(accountDataSchema)
@@ -84,15 +81,12 @@ func (s *accountDataStatements) InsertAccountData(
 	ctx context.Context, txn *sql.Tx,
 	userID, roomID, dataType string,
 ) (pos types.StreamPosition, err error) {
-	return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		var err error
-		pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
-		if err != nil {
-			return err
-		}
-		_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
-		return err
-	})
+	pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+	if err != nil {
+		return
+	}
+	_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
+	return
 }
 
 func (s *accountDataStatements) SelectAccountDataInRange(
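With the per-table writer gone, InsertAccountData now assumes the caller already holds a write transaction. In the shared storage layer this would look roughly like the sketch below; the d.Writer, d.DB and d.AccountData fields follow the shared.Database wiring shown at the end of this diff, but the method itself is an illustration, not a line from this commit:

// Illustrative caller: the component-level writer wraps the table call
// in one transaction, so the table code no longer needs its own writer.
func (d *Database) UpsertAccountData(
	ctx context.Context, userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
	err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
		pos, err = d.AccountData.InsertAccountData(ctx, txn, userID, roomID, dataType)
		return err
	})
	return
}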


@@ -19,7 +19,6 @@ import (
 	"database/sql"
 
 	"github.com/matrix-org/dendrite/internal"
-	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/syncapi/storage/tables"
 )
@@ -49,7 +48,6 @@ const deleteBackwardExtremitySQL = "" +
 type backwardExtremitiesStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	insertBackwardExtremityStmt *sql.Stmt
 	selectBackwardExtremitiesForRoomStmt *sql.Stmt
 	deleteBackwardExtremityStmt *sql.Stmt
@@ -57,8 +55,7 @@ type backwardExtremitiesStatements struct {
 func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
 	s := &backwardExtremitiesStatements{
-		db:     db,
-		writer: sqlutil.NewTransactionWriter(),
+		db: db,
 	}
 	_, err := db.Exec(backwardExtremitiesSchema)
 	if err != nil {
@@ -79,10 +76,8 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
 func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
 	ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
 ) (err error) {
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
-		return err
-	})
+	_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
+	return err
 }
 
 func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
@@ -110,8 +105,6 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
 func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
 	ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
 ) (err error) {
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
-		return err
-	})
+	_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
+	return err
 }


@@ -85,7 +85,6 @@ const selectEventsWithEventIDsSQL = "" +
 type currentRoomStateStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	streamIDStatements *streamIDStatements
 	upsertRoomStateStmt *sql.Stmt
 	deleteRoomStateByEventIDStmt *sql.Stmt
@@ -98,7 +97,6 @@ type currentRoomStateStatements struct {
 func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
 	s := &currentRoomStateStatements{
 		db: db,
-		writer: sqlutil.NewTransactionWriter(),
 		streamIDStatements: streamID,
 	}
 	_, err := db.Exec(currentRoomStateSchema)
@@ -200,11 +198,9 @@ func (s *currentRoomStateStatements) SelectCurrentState(
 func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
 	ctx context.Context, txn *sql.Tx, eventID string,
 ) error {
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
-		_, err := stmt.ExecContext(ctx, eventID)
-		return err
-	})
+	stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
+	_, err := stmt.ExecContext(ctx, eventID)
+	return err
 }
 
 func (s *currentRoomStateStatements) UpsertRoomState(
@@ -225,22 +221,20 @@ func (s *currentRoomStateStatements) UpsertRoomState(
 	}
 
 	// upsert state event
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
-		_, err := stmt.ExecContext(
-			ctx,
-			event.RoomID(),
-			event.EventID(),
-			event.Type(),
-			event.Sender(),
-			containsURL,
-			*event.StateKey(),
-			headeredJSON,
-			membership,
-			addedAt,
-		)
-		return err
-	})
+	stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
+	_, err = stmt.ExecContext(
+		ctx,
+		event.RoomID(),
+		event.EventID(),
+		event.Type(),
+		event.Sender(),
+		containsURL,
+		*event.StateKey(),
+		headeredJSON,
+		membership,
+		addedAt,
+	)
+	return err
 }
 
 func minOfInts(a, b int) int {


@@ -20,7 +20,6 @@ import (
 	"encoding/json"
 	"fmt"
 
-	"github.com/matrix-org/dendrite/internal/sqlutil"
 	"github.com/matrix-org/dendrite/syncapi/storage/tables"
 	"github.com/matrix-org/gomatrixserverlib"
 )
@@ -52,7 +51,6 @@ const insertFilterSQL = "" +
 type filterStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	selectFilterStmt *sql.Stmt
 	selectFilterIDByContentStmt *sql.Stmt
 	insertFilterStmt *sql.Stmt
@@ -64,8 +62,7 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
 		return nil, err
 	}
 	s := &filterStatements{
-		db:     db,
-		writer: sqlutil.NewTransactionWriter(),
+		db: db,
 	}
 	if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
 		return nil, err
@@ -114,33 +111,30 @@ func (s *filterStatements) InsertFilter(
 		return "", err
 	}
 
-	err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		// Check if filter already exists in the database using its localpart and content
-		//
-		// This can result in a race condition when two clients try to insert the
-		// same filter and localpart at the same time, however this is not a
-		// problem as both calls will result in the same filterID
-		err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
-			localpart, filterJSON).Scan(&existingFilterID)
-		if err != nil && err != sql.ErrNoRows {
-			return err
-		}
-		// If it does, return the existing ID
-		if existingFilterID != "" {
-			return nil
-		}
+	// Check if filter already exists in the database using its localpart and content
+	//
+	// This can result in a race condition when two clients try to insert the
+	// same filter and localpart at the same time, however this is not a
+	// problem as both calls will result in the same filterID
+	err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
+		localpart, filterJSON).Scan(&existingFilterID)
+	if err != nil && err != sql.ErrNoRows {
+		return "", err
+	}
+	// If it does, return the existing ID
+	if existingFilterID != "" {
+		return existingFilterID, nil
+	}
 
-		// Otherwise insert the filter and return the new ID
-		res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
-		if err != nil {
-			return err
-		}
-		rowid, err := res.LastInsertId()
-		if err != nil {
-			return err
-		}
-		filterID = fmt.Sprintf("%d", rowid)
-		return nil
-	})
+	// Otherwise insert the filter and return the new ID
+	res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
+	if err != nil {
+		return "", err
+	}
+	rowid, err := res.LastInsertId()
+	if err != nil {
+		return "", err
+	}
+	filterID = fmt.Sprintf("%d", rowid)
+	return
 }


@@ -59,7 +59,6 @@ const selectMaxInviteIDSQL = "" +
 type inviteEventsStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	streamIDStatements *streamIDStatements
 	insertInviteEventStmt *sql.Stmt
 	selectInviteEventsInRangeStmt *sql.Stmt
@@ -70,7 +69,6 @@ type inviteEventsStatements struct {
 func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
 	s := &inviteEventsStatements{
 		db: db,
-		writer: sqlutil.NewTransactionWriter(),
 		streamIDStatements: streamID,
 	}
 	_, err := db.Exec(inviteEventsSchema)
@@ -95,45 +93,37 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
 func (s *inviteEventsStatements) InsertInviteEvent(
 	ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
 ) (streamPos types.StreamPosition, err error) {
-	err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		var err error
-		streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
-		if err != nil {
-			return err
-		}
+	streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
+	if err != nil {
+		return
+	}
 
-		var headeredJSON []byte
-		headeredJSON, err = json.Marshal(inviteEvent)
-		if err != nil {
-			return err
-		}
+	var headeredJSON []byte
+	headeredJSON, err = json.Marshal(inviteEvent)
+	if err != nil {
+		return
+	}
 
-		_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
-			ctx,
-			streamPos,
-			inviteEvent.RoomID(),
-			inviteEvent.EventID(),
-			*inviteEvent.StateKey(),
-			headeredJSON,
-		)
-		return err
-	})
+	stmt := sqlutil.TxStmt(txn, s.insertInviteEventStmt)
+	_, err = stmt.ExecContext(
+		ctx,
+		streamPos,
+		inviteEvent.RoomID(),
+		inviteEvent.EventID(),
+		*inviteEvent.StateKey(),
+		headeredJSON,
+	)
+	return
 }
 
 func (s *inviteEventsStatements) DeleteInviteEvent(
 	ctx context.Context, inviteEventID string,
 ) (types.StreamPosition, error) {
-	var streamPos types.StreamPosition
-	err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		var err error
-		streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
-		if err != nil {
-			return err
-		}
-		_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
-		return err
-	})
+	streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
+	if err != nil {
+		return streamPos, err
+	}
+	_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
+	return streamPos, err
 }


@@ -105,7 +105,6 @@ const selectStateInRangeSQL = "" +
 type outputRoomEventsStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	streamIDStatements *streamIDStatements
 	insertEventStmt *sql.Stmt
 	selectEventsStmt *sql.Stmt
@@ -120,7 +119,6 @@ type outputRoomEventsStatements struct {
 func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
 	s := &outputRoomEventsStatements{
 		db: db,
-		writer: sqlutil.NewTransactionWriter(),
 		streamIDStatements: streamID,
 	}
 	_, err := db.Exec(outputRoomEventsSchema)
@@ -159,10 +157,8 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
 	if err != nil {
 		return err
 	}
-	return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
-		return err
-	})
+	_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
+	return err
 }
 
 // selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
@@ -304,32 +300,27 @@ func (s *outputRoomEventsStatements) InsertEvent(
 		return 0, err
 	}
-	var streamPos types.StreamPosition
-	err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
-		if err != nil {
-			return err
-		}
-		insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
-		_, ierr := insertStmt.ExecContext(
-			ctx,
-			streamPos,
-			event.RoomID(),
-			event.EventID(),
-			headeredJSON,
-			event.Type(),
-			event.Sender(),
-			containsURL,
-			string(addStateJSON),
-			string(removeStateJSON),
-			sessionID,
-			txnID,
-			excludeFromSync,
-			excludeFromSync,
-		)
-		return ierr
-	})
+	streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
+	if err != nil {
+		return 0, err
+	}
+	insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
+	_, err = insertStmt.ExecContext(
+		ctx,
+		streamPos,
+		event.RoomID(),
+		event.EventID(),
+		headeredJSON,
+		event.Type(),
+		event.Sender(),
+		containsURL,
+		string(addStateJSON),
+		string(removeStateJSON),
+		sessionID,
+		txnID,
+		excludeFromSync,
+		excludeFromSync,
+	)
 	return streamPos, err
 }


@@ -67,7 +67,6 @@ const selectMaxPositionInTopologySQL = "" +
 type outputRoomEventsTopologyStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	insertEventInTopologyStmt *sql.Stmt
 	selectEventIDsInRangeASCStmt *sql.Stmt
 	selectEventIDsInRangeDESCStmt *sql.Stmt
@@ -77,8 +76,7 @@ type outputRoomEventsTopologyStatements struct {
 func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
 	s := &outputRoomEventsTopologyStatements{
-		db:     db,
-		writer: sqlutil.NewTransactionWriter(),
+		db: db,
 	}
 	_, err := db.Exec(outputRoomEventsTopologySchema)
 	if err != nil {
@@ -107,13 +105,11 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
 func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
 	ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
 ) (err error) {
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
-		_, err := stmt.ExecContext(
-			ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
-		)
-		return err
-	})
+	stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
+	_, err = stmt.ExecContext(
+		ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
+	)
+	return
 }
 
 func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(


@@ -73,7 +73,6 @@ const deleteSendToDeviceMessagesSQL = `
 type sendToDeviceStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	insertSendToDeviceMessageStmt *sql.Stmt
 	selectSendToDeviceMessagesStmt *sql.Stmt
 	countSendToDeviceMessagesStmt *sql.Stmt
@@ -81,8 +80,7 @@ type sendToDeviceStatements struct {
 func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
 	s := &sendToDeviceStatements{
-		db:     db,
-		writer: sqlutil.NewTransactionWriter(),
+		db: db,
 	}
 	_, err := db.Exec(sendToDeviceSchema)
 	if err != nil {
@@ -103,10 +101,8 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
 func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
 	ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
 ) (err error) {
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
-		return err
-	})
+	_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
+	return
 }
 
 func (s *sendToDeviceStatements) CountSendToDeviceMessages(
@@ -163,10 +159,8 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
 	for k, v := range nids {
 		params[k+1] = v
 	}
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		_, err := txn.ExecContext(ctx, query, params...)
-		return err
-	})
+	_, err = txn.ExecContext(ctx, query, params...)
+	return
 }
 
 func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
@@ -177,8 +171,6 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
 	for k, v := range nids {
 		params[k] = v
 	}
-	return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
-		_, err := txn.ExecContext(ctx, query, params...)
-		return err
-	})
+	_, err = txn.ExecContext(ctx, query, params...)
+	return
 }


@@ -28,14 +28,12 @@ const selectStreamIDStmt = "" +
 type streamIDStatements struct {
 	db *sql.DB
-	writer sqlutil.TransactionWriter
 	increaseStreamIDStmt *sql.Stmt
 	selectStreamIDStmt *sql.Stmt
 }
 
 func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
 	s.db = db
-	s.writer = sqlutil.NewTransactionWriter()
 	_, err = db.Exec(streamIDTableSchema)
 	if err != nil {
 		return
@@ -52,14 +50,9 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
 func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
 	increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
 	selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
-	err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
-		if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
-			return ierr
-		}
-		if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
-			return serr
-		}
-		return nil
-	})
+	if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
+		return
+	}
+	err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos)
+	return
 }


@@ -31,7 +31,8 @@ import (
 // both the database for PDUs and caches for EDUs.
 type SyncServerDatasource struct {
 	shared.Database
-	db *sql.DB
+	db     *sql.DB
+	writer sqlutil.Writer
 	sqlutil.PartitionOffsetStatements
 	streamID streamIDStatements
 }
@@ -44,6 +45,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
 	if d.db, err = sqlutil.Open(dbProperties); err != nil {
 		return nil, err
 	}
+	d.writer = sqlutil.NewExclusiveWriter()
 	if err = d.prepare(); err != nil {
 		return nil, err
 	}
@@ -51,7 +53,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
 }
 
 func (d *SyncServerDatasource) prepare() (err error) {
-	if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {
+	if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "syncapi"); err != nil {
 		return err
 	}
 	if err = d.streamID.prepare(d.db); err != nil {
@@ -91,6 +93,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
 	}
 	d.Database = shared.Database{
 		DB: d.db,
+		Writer: sqlutil.NewExclusiveWriter(),
 		Invites: invites,
 		AccountData: accountData,
 		OutputEvents: events,
@@ -99,7 +102,6 @@ func (d *SyncServerDatasource) prepare() (err error) {
 		Topology: topology,
 		Filter: filter,
 		SendToDevice: sendToDevice,
-		SendToDeviceWriter: sqlutil.NewTransactionWriter(),
 		EDUCache: cache.New(),
 	}
 	return nil
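
The diff never shows the body of sqlutil.NewExclusiveWriter, only that it replaces NewTransactionWriter here and gets handed to PartitionOffsetStatements.Prepare and shared.Database. Going by the name and by the deadlock fixes listed in the commit message, a minimal writer with the same contract could be sketched like this; it is an assumption for illustration, not Dendrite's actual implementation, which may serialise writes differently:

import (
	"database/sql"
	"sync"
)

// exclusiveWriter is a sketch of a sqlutil.Writer that permits only one
// write at a time, which is the property SQLite needs to avoid the
// "database is locked" class of failures.
type exclusiveWriter struct {
	mu sync.Mutex
}

func (w *exclusiveWriter) Do(db *sql.DB, txn *sql.Tx, f func(txn *sql.Tx) error) error {
	if txn != nil {
		// The caller already holds the write transaction (and, in this
		// sketch, is assumed to have gone through Do to open it).
		return f(txn)
	}
	w.mu.Lock()
	defer w.mu.Unlock()
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	if err := f(tx); err != nil {
		tx.Rollback() // best-effort rollback in this sketch
		return err
	}
	return tx.Commit()
}

A single writer owned by SyncServerDatasource and passed into shared.Database is also presumably why the separate SendToDeviceWriter field disappears in the last hunk: one writer per database file is enough.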