mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 06:12:45 +00:00
syncapi: add more tests; fix more bugs (#2338)
* syncapi: add more tests; fix more bugs bugfixes: - The postgres impl of TopologyTable.SelectEventIDsInRange did not use the provided txn - The postgres impl of EventsTable.SelectEvents did not preserve the ordering of the input event IDs in the output events slice - The sqlite impl of EventsTable.SelectEvents did not use a bulk `IN ($1)` query. Added tests: - `TestGetEventsInRangeWithTopologyToken` - `TestOutputRoomEventsTable` - `TestTopologyTable` * -p 1 for now
This commit is contained in:
parent
986d27a128
commit
6d25bd6ca5
20 changed files with 388 additions and 197 deletions
|
@ -51,13 +51,13 @@ const selectMaxAccountDataIDSQL = "" +
|
|||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertAccountDataStmt *sql.Stmt
|
||||
selectMaxAccountDataIDStmt *sql.Stmt
|
||||
selectAccountDataInRangeStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *streamIDStatements) (tables.AccountData, error) {
|
||||
func NewSqliteAccountDataTable(db *sql.DB, streamID *StreamIDStatements) (tables.AccountData, error) {
|
||||
s := &accountDataStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
|
@ -90,7 +90,7 @@ const selectEventsWithEventIDsSQL = "" +
|
|||
|
||||
type currentRoomStateStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertRoomStateStmt *sql.Stmt
|
||||
deleteRoomStateByEventIDStmt *sql.Stmt
|
||||
deleteRoomStateForRoomStmt *sql.Stmt
|
||||
|
@ -100,7 +100,7 @@ type currentRoomStateStatements struct {
|
|||
selectStateEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (tables.CurrentRoomState, error) {
|
||||
func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *StreamIDStatements) (tables.CurrentRoomState, error) {
|
||||
s := ¤tRoomStateStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
|
@ -59,14 +59,14 @@ const selectMaxInviteIDSQL = "" +
|
|||
|
||||
type inviteEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertInviteEventStmt *sql.Stmt
|
||||
selectInviteEventsInRangeStmt *sql.Stmt
|
||||
deleteInviteEventStmt *sql.Stmt
|
||||
selectMaxInviteIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Invites, error) {
|
||||
func NewSqliteInvitesTable(db *sql.DB, streamID *StreamIDStatements) (tables.Invites, error) {
|
||||
s := &inviteEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
|
|
@ -58,7 +58,7 @@ const insertEventSQL = "" +
|
|||
"ON CONFLICT (event_id) DO UPDATE SET exclude_from_sync = (excluded.exclude_from_sync AND $13)"
|
||||
|
||||
const selectEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id = $1"
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events WHERE event_id IN ($1)"
|
||||
|
||||
const selectRecentEventsSQL = "" +
|
||||
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
|
||||
|
@ -111,9 +111,8 @@ const selectContextAfterEventSQL = "" +
|
|||
|
||||
type outputRoomEventsStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertEventStmt *sql.Stmt
|
||||
selectEventsStmt *sql.Stmt
|
||||
selectMaxEventIDStmt *sql.Stmt
|
||||
updateEventJSONStmt *sql.Stmt
|
||||
deleteEventsForRoomStmt *sql.Stmt
|
||||
|
@ -122,7 +121,7 @@ type outputRoomEventsStatements struct {
|
|||
selectContextAfterEventStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
|
||||
func NewSqliteEventsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Events, error) {
|
||||
s := &outputRoomEventsStatements{
|
||||
db: db,
|
||||
streamIDStatements: streamID,
|
||||
|
@ -133,7 +132,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
|
|||
}
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventsStmt, selectEventsSQL},
|
||||
{&s.selectMaxEventIDStmt, selectMaxEventIDSQL},
|
||||
{&s.updateEventJSONStmt, updateEventJSONSQL},
|
||||
{&s.deleteEventsForRoomStmt, deleteEventsForRoomSQL},
|
||||
|
@ -421,21 +419,43 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
|||
// selectEvents returns the events for the given event IDs. If an event is
|
||||
// missing from the database, it will be omitted.
|
||||
func (s *outputRoomEventsStatements) SelectEvents(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string, preserveOrder bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
var returnEvents []types.StreamEvent
|
||||
stmt := sqlutil.TxStmt(txn, s.selectEventsStmt)
|
||||
for _, eventID := range eventIDs {
|
||||
rows, err := stmt.QueryContext(ctx, eventID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if streamEvents, err := rowsToStreamEvents(rows); err == nil {
|
||||
returnEvents = append(returnEvents, streamEvents...)
|
||||
}
|
||||
internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
iEventIDs := make([]interface{}, len(eventIDs))
|
||||
for i := range eventIDs {
|
||||
iEventIDs[i] = eventIDs[i]
|
||||
}
|
||||
return returnEvents, nil
|
||||
selectSQL := strings.Replace(selectEventsSQL, "($1)", sqlutil.QueryVariadic(len(eventIDs)), 1)
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if txn != nil {
|
||||
rows, err = txn.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||
} else {
|
||||
rows, err = s.db.QueryContext(ctx, selectSQL, iEventIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectEvents: rows.close() failed")
|
||||
streamEvents, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if preserveOrder {
|
||||
var returnEvents []types.StreamEvent
|
||||
eventMap := make(map[string]types.StreamEvent)
|
||||
for _, ev := range streamEvents {
|
||||
eventMap[ev.EventID()] = ev
|
||||
}
|
||||
for _, eventID := range eventIDs {
|
||||
ev, ok := eventMap[eventID]
|
||||
if ok {
|
||||
returnEvents = append(returnEvents, ev)
|
||||
}
|
||||
}
|
||||
return returnEvents, nil
|
||||
}
|
||||
return streamEvents, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) DeleteEventsForRoom(
|
||||
|
|
|
@ -66,7 +66,7 @@ const selectMaxPeekIDSQL = "" +
|
|||
|
||||
type peekStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
insertPeekStmt *sql.Stmt
|
||||
deletePeekStmt *sql.Stmt
|
||||
deletePeeksStmt *sql.Stmt
|
||||
|
@ -75,7 +75,7 @@ type peekStatements struct {
|
|||
selectMaxPeekIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks, error) {
|
||||
func NewSqlitePeeksTable(db *sql.DB, streamID *StreamIDStatements) (tables.Peeks, error) {
|
||||
_, err := db.Exec(peeksSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -75,7 +75,7 @@ const selectPresenceAfter = "" +
|
|||
|
||||
type presenceStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertPresenceStmt *sql.Stmt
|
||||
upsertPresenceFromSyncStmt *sql.Stmt
|
||||
selectPresenceForUsersStmt *sql.Stmt
|
||||
|
@ -83,7 +83,7 @@ type presenceStatements struct {
|
|||
selectPresenceAfterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *streamIDStatements) (*presenceStatements, error) {
|
||||
func NewSqlitePresenceTable(db *sql.DB, streamID *StreamIDStatements) (*presenceStatements, error) {
|
||||
_, err := db.Exec(presenceSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -59,13 +59,13 @@ const selectMaxReceiptIDSQL = "" +
|
|||
|
||||
type receiptStatements struct {
|
||||
db *sql.DB
|
||||
streamIDStatements *streamIDStatements
|
||||
streamIDStatements *StreamIDStatements
|
||||
upsertReceipt *sql.Stmt
|
||||
selectRoomReceipts *sql.Stmt
|
||||
selectMaxReceiptID *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *streamIDStatements) (tables.Receipts, error) {
|
||||
func NewSqliteReceiptsTable(db *sql.DB, streamID *StreamIDStatements) (tables.Receipts, error) {
|
||||
_, err := db.Exec(receiptsSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -32,12 +32,12 @@ const increaseStreamIDStmt = "" +
|
|||
"UPDATE syncapi_stream_id SET stream_id = stream_id + 1 WHERE stream_name = $1" +
|
||||
" RETURNING stream_id"
|
||||
|
||||
type streamIDStatements struct {
|
||||
type StreamIDStatements struct {
|
||||
db *sql.DB
|
||||
increaseStreamIDStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
||||
func (s *StreamIDStatements) Prepare(db *sql.DB) (err error) {
|
||||
s.db = db
|
||||
_, err = db.Exec(streamIDTableSchema)
|
||||
if err != nil {
|
||||
|
@ -49,31 +49,31 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "global").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "invite").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *streamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
func (s *StreamIDStatements) nextPresenceID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
|
||||
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
|
||||
err = increaseStmt.QueryRowContext(ctx, "presence").Scan(&pos)
|
||||
return
|
||||
|
|
|
@ -30,7 +30,7 @@ type SyncServerDatasource struct {
|
|||
shared.Database
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
streamID streamIDStatements
|
||||
streamID StreamIDStatements
|
||||
}
|
||||
|
||||
// NewDatabase creates a new sync server database
|
||||
|
@ -49,7 +49,7 @@ func NewDatabase(dbProperties *config.DatabaseOptions) (*SyncServerDatasource, e
|
|||
}
|
||||
|
||||
func (d *SyncServerDatasource) prepare(dbProperties *config.DatabaseOptions) (err error) {
|
||||
if err = d.streamID.prepare(d.db); err != nil {
|
||||
if err = d.streamID.Prepare(d.db); err != nil {
|
||||
return err
|
||||
}
|
||||
accountData, err := NewSqliteAccountDataTable(d.db, &d.streamID)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue