mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 05:42:46 +00:00
Use TransactionWriter in other component SQLite (#1209)
* Use TransactionWriter on other component SQLites * Fix sync API tests * Fix panic in media API * Fix a couple of transactions * Fix wrong query, add some logging output * Add debug logging into StoreEvent * Adjust InsertRoomNID * Update logging
This commit is contained in:
parent
1d72ce8b7a
commit
b6bc132485
27 changed files with 439 additions and 245 deletions
|
@ -281,16 +281,16 @@ func (d *Database) WriteEvent(
|
|||
ctx, txn, ev, addStateEventIDs, removeStateEventIDs, transactionID, excludeFromSync,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.OutputEvents.InsertEvent: %w", err)
|
||||
}
|
||||
pduPosition = pos
|
||||
|
||||
if err = d.Topology.InsertEventInTopology(ctx, txn, ev, pos); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.Topology.InsertEventInTopology: %w", err)
|
||||
}
|
||||
|
||||
if err = d.handleBackwardExtremities(ctx, txn, ev); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.handleBackwardExtremities: %w", err)
|
||||
}
|
||||
|
||||
if len(addStateEvents) == 0 && len(removeStateEventIDs) == 0 {
|
||||
|
@ -313,7 +313,7 @@ func (d *Database) updateRoomState(
|
|||
// remove first, then add, as we do not ever delete state, but do replace state which is a remove followed by an add.
|
||||
for _, eventID := range removedEventIDs {
|
||||
if err := d.CurrentRoomState.DeleteRoomStateByEventID(ctx, txn, eventID); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.CurrentRoomState.DeleteRoomStateByEventID: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -326,13 +326,13 @@ func (d *Database) updateRoomState(
|
|||
if event.Type() == "m.room.member" {
|
||||
value, err := event.Membership()
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("event.Membership: %w", err)
|
||||
}
|
||||
membership = &value
|
||||
}
|
||||
|
||||
if err := d.CurrentRoomState.UpsertRoomState(ctx, txn, event, membership, pduPosition); err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.CurrentRoomState.UpsertRoomState: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -49,6 +50,8 @@ const selectMaxAccountDataIDSQL = "" +
|
|||
"SELECT MAX(id) FROM syncapi_account_data_type"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
|
@ -57,6 +60,8 @@ type accountDataStatements struct {
|
|||
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(accountDataSchema)
|
||||
|
@ -79,12 +84,15 @@ func (s *accountDataStatements) InsertAccountData(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
userID, roomID, dataType string,
|
||||
) (pos types.StreamPosition, err error) {
|
||||
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||
return
|
||||
return pos, s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = txn.Stmt(s.insertAccountDataStmt).ExecContext(ctx, pos, userID, roomID, dataType)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataInRange(
|
||||
|
|
|
@ -19,6 +19,7 @@ import (
|
|||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
)
|
||||
|
||||
|
@ -47,13 +48,18 @@ const deleteBackwardExtremitySQL = "" +
|
|||
"DELETE FROM syncapi_backward_extremities WHERE room_id = $1 AND prev_event_id = $2"
|
||||
|
||||
type backwardExtremitiesStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertBackwardExtremityStmt *sql.Stmt
|
||||
selectBackwardExtremitiesForRoomStmt *sql.Stmt
|
||||
deleteBackwardExtremityStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities, error) {
|
||||
s := &backwardExtremitiesStatements{}
|
||||
s := &backwardExtremitiesStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(backwardExtremitiesSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -73,8 +79,10 @@ func NewSqliteBackwardsExtremitiesTable(db *sql.DB) (tables.BackwardsExtremities
|
|||
func (s *backwardExtremitiesStatements) InsertsBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, eventID string, prevEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := txn.Stmt(s.insertBackwardExtremityStmt).ExecContext(ctx, roomID, eventID, prevEventID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
||||
|
@ -102,6 +110,8 @@ func (s *backwardExtremitiesStatements) SelectBackwardExtremitiesForRoom(
|
|||
func (s *backwardExtremitiesStatements) DeleteBackwardExtremity(
|
||||
ctx context.Context, txn *sql.Tx, roomID, knownEventID string,
|
||||
) (err error) {
|
||||
_, err = txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := txn.Stmt(s.deleteBackwardExtremityStmt).ExecContext(ctx, roomID, knownEventID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
|
|
@ -84,6 +84,8 @@ const selectEventsWithEventIDsSQL = "" +
|
|||
" FROM syncapi_current_room_state WHERE event_id IN ($1)"
|
||||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
|
@ -95,6 +97,8 @@ type currentRoomStateStatements struct {
|
|||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||
s := ¤tRoomStateStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(currentRoomStateSchema)
|
||||
|
@ -196,9 +200,11 @@ func (s *currentRoomStateStatements) SelectCurrentState(
|
|||
func (s *currentRoomStateStatements) DeleteRoomStateByEventID(
|
||||
ctx context.Context, txn *sql.Tx, eventID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, eventID)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRoomStateByEventIDStmt)
|
||||
_, err := stmt.ExecContext(ctx, eventID)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) UpsertRoomState(
|
||||
|
@ -219,20 +225,22 @@ func (s *currentRoomStateStatements) UpsertRoomState(
|
|||
}
|
||||
|
||||
// upsert state event
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||
_, err = stmt.ExecContext(
|
||||
ctx,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
*event.StateKey(),
|
||||
headeredJSON,
|
||||
membership,
|
||||
addedAt,
|
||||
)
|
||||
return err
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertRoomStateStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
*event.StateKey(),
|
||||
headeredJSON,
|
||||
membership,
|
||||
addedAt,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *currentRoomStateStatements) SelectEventsWithEventIDs(
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
@ -50,6 +51,8 @@ const insertFilterSQL = "" +
|
|||
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
|
||||
|
||||
type filterStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
|
@ -60,7 +63,10 @@ func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &filterStatements{}
|
||||
s := &filterStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -108,30 +114,33 @@ func (s *filterStatements) InsertFilter(
|
|||
return "", err
|
||||
}
|
||||
|
||||
// Check if filter already exists in the database using its localpart and content
|
||||
//
|
||||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// If it does, return the existing ID
|
||||
if existingFilterID != "" {
|
||||
return existingFilterID, err
|
||||
}
|
||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
// Check if filter already exists in the database using its localpart and content
|
||||
//
|
||||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
// If it does, return the existing ID
|
||||
if existingFilterID != "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rowid, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filterID = fmt.Sprintf("%d", rowid)
|
||||
// Otherwise insert the filter and return the new ID
|
||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rowid, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filterID = fmt.Sprintf("%d", rowid)
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -58,6 +58,8 @@ const selectMaxInviteIDSQL = "" +
|
|||
"SELECT MAX(id) FROM syncapi_invite_events"
|
||||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
|
@ -67,6 +69,8 @@ type inviteEventsStatements struct {
|
|||
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||
s := &inviteEventsStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(inviteEventsSchema)
|
||||
|
@ -91,36 +95,45 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
|
|||
func (s *inviteEventsStatements) InsertInviteEvent(
|
||||
ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent,
|
||||
) (streamPos types.StreamPosition, err error) {
|
||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var headeredJSON []byte
|
||||
headeredJSON, err = json.Marshal(inviteEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var headeredJSON []byte
|
||||
headeredJSON, err = json.Marshal(inviteEvent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
inviteEvent.RoomID(),
|
||||
inviteEvent.EventID(),
|
||||
*inviteEvent.StateKey(),
|
||||
headeredJSON,
|
||||
)
|
||||
_, err = txn.Stmt(s.insertInviteEventStmt).ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
inviteEvent.RoomID(),
|
||||
inviteEvent.EventID(),
|
||||
*inviteEvent.StateKey(),
|
||||
headeredJSON,
|
||||
)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inviteEventsStatements) DeleteInviteEvent(
|
||||
ctx context.Context, inviteEventID string,
|
||||
) (types.StreamPosition, error) {
|
||||
streamPos, err := s.streamIDStatements.nextStreamID(ctx, nil)
|
||||
if err != nil {
|
||||
return streamPos, err
|
||||
}
|
||||
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||
var streamPos types.StreamPosition
|
||||
err := s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.deleteInviteEventStmt.ExecContext(ctx, streamPos, inviteEventID)
|
||||
return err
|
||||
})
|
||||
return streamPos, err
|
||||
}
|
||||
|
||||
|
|
|
@ -104,6 +104,8 @@ const selectStateInRangeSQL = "" +
|
|||
" LIMIT $8" // limit
|
||||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
streamIDStatements *streamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
|
@ -117,6 +119,8 @@ type outputRoomEventsStatements struct {
|
|||
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||
s := &outputRoomEventsStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
streamIDStatements: streamID,
|
||||
}
|
||||
_, err := db.Exec(outputRoomEventsSchema)
|
||||
|
@ -155,8 +159,10 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||
return err
|
||||
return s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
_, err = s.updateEventJSONStmt.ExecContext(ctx, headeredJSON, event.EventID())
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
// selectStateInRange returns the state events between the two given PDU stream positions, exclusive of oldPos, inclusive of newPos.
|
||||
|
@ -267,7 +273,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
event *gomatrixserverlib.HeaderedEvent, addState, removeState []string,
|
||||
transactionID *api.TransactionID, excludeFromSync bool,
|
||||
) (streamPos types.StreamPosition, err error) {
|
||||
) (types.StreamPosition, error) {
|
||||
var txnID *string
|
||||
var sessionID *int64
|
||||
if transactionID != nil {
|
||||
|
@ -284,43 +290,47 @@ func (s *outputRoomEventsStatements) InsertEvent(
|
|||
}
|
||||
|
||||
var headeredJSON []byte
|
||||
headeredJSON, err = json.Marshal(event)
|
||||
headeredJSON, err := json.Marshal(event)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
addStateJSON, err := json.Marshal(addState)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
removeStateJSON, err := json.Marshal(removeState)
|
||||
if err != nil {
|
||||
return
|
||||
return 0, err
|
||||
}
|
||||
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
_, err = insertStmt.ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
headeredJSON,
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
string(addStateJSON),
|
||||
string(removeStateJSON),
|
||||
sessionID,
|
||||
txnID,
|
||||
excludeFromSync,
|
||||
excludeFromSync,
|
||||
)
|
||||
return
|
||||
var streamPos types.StreamPosition
|
||||
err = s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertEventStmt)
|
||||
_, ierr := insertStmt.ExecContext(
|
||||
ctx,
|
||||
streamPos,
|
||||
event.RoomID(),
|
||||
event.EventID(),
|
||||
headeredJSON,
|
||||
event.Type(),
|
||||
event.Sender(),
|
||||
containsURL,
|
||||
string(addStateJSON),
|
||||
string(removeStateJSON),
|
||||
sessionID,
|
||||
txnID,
|
||||
excludeFromSync,
|
||||
excludeFromSync,
|
||||
)
|
||||
return ierr
|
||||
})
|
||||
return streamPos, err
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) SelectRecentEvents(
|
||||
|
|
|
@ -66,6 +66,8 @@ const selectMaxPositionInTopologySQL = "" +
|
|||
" WHERE room_id = $1 ORDER BY stream_position DESC"
|
||||
|
||||
type outputRoomEventsTopologyStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertEventInTopologyStmt *sql.Stmt
|
||||
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
|
@ -74,7 +76,10 @@ type outputRoomEventsTopologyStatements struct {
|
|||
}
|
||||
|
||||
func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
||||
s := &outputRoomEventsTopologyStatements{}
|
||||
s := &outputRoomEventsTopologyStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(outputRoomEventsTopologySchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -102,11 +107,13 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
|||
func (s *outputRoomEventsTopologyStatements) InsertEventInTopology(
|
||||
ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, pos types.StreamPosition,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
||||
_, err = stmt.ExecContext(
|
||||
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
||||
)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertEventInTopologyStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx, event.EventID(), event.Depth(), event.RoomID(), pos,
|
||||
)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) SelectEventIDsInRange(
|
||||
|
|
|
@ -72,13 +72,18 @@ const deleteSendToDeviceMessagesSQL = `
|
|||
`
|
||||
|
||||
type sendToDeviceStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
insertSendToDeviceMessageStmt *sql.Stmt
|
||||
selectSendToDeviceMessagesStmt *sql.Stmt
|
||||
countSendToDeviceMessagesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
||||
s := &sendToDeviceStatements{}
|
||||
s := &sendToDeviceStatements{
|
||||
db: db,
|
||||
writer: sqlutil.NewTransactionWriter(),
|
||||
}
|
||||
_, err := db.Exec(sendToDeviceSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -98,8 +103,10 @@ func NewSqliteSendToDeviceTable(db *sql.DB) (tables.SendToDevice, error) {
|
|||
func (s *sendToDeviceStatements) InsertSendToDeviceMessage(
|
||||
ctx context.Context, txn *sql.Tx, userID, deviceID, content string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertSendToDeviceMessageStmt).ExecContext(ctx, userID, deviceID, content)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) CountSendToDeviceMessages(
|
||||
|
@ -156,8 +163,10 @@ func (s *sendToDeviceStatements) UpdateSentSendToDeviceMessages(
|
|||
for k, v := range nids {
|
||||
params[k+1] = v
|
||||
}
|
||||
_, err = txn.ExecContext(ctx, query, params...)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, query, params...)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
||||
|
@ -168,6 +177,8 @@ func (s *sendToDeviceStatements) DeleteSendToDeviceMessages(
|
|||
for k, v := range nids {
|
||||
params[k] = v
|
||||
}
|
||||
_, err = txn.ExecContext(ctx, query, params...)
|
||||
return
|
||||
return s.writer.Do(s.db, txn, func(txn *sql.Tx) error {
|
||||
_, err := txn.ExecContext(ctx, query, params...)
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
|
|
@ -27,11 +27,15 @@ const selectStreamIDStmt = "" +
|
|||
"SELECT stream_id FROM syncapi_stream_id WHERE stream_name = $1"
|
||||
|
||||
type streamIDStatements struct {
|
||||
db *sql.DB
|
||||
writer *sqlutil.TransactionWriter
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
selectStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
s.writer = sqlutil.NewTransactionWriter()
|
||||
_, err = db.Exec(streamIDTableSchema)
|
||||
if err != nil {
|
||||
return
|
||||
|
@ -48,11 +52,14 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
|||
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
|
||||
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
|
||||
return
|
||||
}
|
||||
if err = selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
|
||||
return
|
||||
}
|
||||
err = s.writer.Do(s.db, nil, func(txn *sql.Tx) error {
|
||||
if _, ierr := increaseStmt.ExecContext(ctx, "global"); err != nil {
|
||||
return ierr
|
||||
}
|
||||
if serr := selectStmt.QueryRowContext(ctx, "global").Scan(&pos); err != nil {
|
||||
return serr
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -52,7 +53,13 @@ func MustCreateEvent(t *testing.T, roomID string, prevs []gomatrixserverlib.Head
|
|||
}
|
||||
|
||||
func MustCreateDatabase(t *testing.T) storage.Database {
|
||||
db, err := sqlite3.NewDatabase("file::memory:")
|
||||
dbname := fmt.Sprintf("test_%s.db", t.Name())
|
||||
if _, err := os.Stat(dbname); err == nil {
|
||||
if err = os.Remove(dbname); err != nil {
|
||||
t.Fatalf("tried to delete stale test database but failed: %s", err)
|
||||
}
|
||||
}
|
||||
db, err := sqlite3.NewDatabase(fmt.Sprintf("file:%s", dbname))
|
||||
if err != nil {
|
||||
t.Fatalf("NewSyncServerDatasource returned %s", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue