* Add Range

* Use Range
This commit is contained in:
Kegsay 2020-05-15 09:41:12 +01:00 committed by GitHub
parent 419ff150d4
commit 2b5052eccf
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 108 additions and 74 deletions

View file

@ -64,7 +64,7 @@ type Database interface {
// Returns a map following the format data[roomID] = []dataTypes // Returns a map following the format data[roomID] = []dataTypes
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
GetAccountDataInRange(ctx context.Context, userID string, oldPos, newPos types.StreamPosition, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error) GetAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataFilterPart *gomatrixserverlib.EventFilter) (map[string][]string, error)
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
// of the new/updated data, and the user ID and room ID the data is related to (empty) // of the new/updated data, and the user ID and room ID the data is related to (empty)
// room ID means the data isn't specific to any room) // room ID means the data isn't specific to any room)

View file

@ -100,19 +100,12 @@ func (s *accountDataStatements) InsertAccountData(
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, r types.Range,
accountDataEventFilter *gomatrixserverlib.EventFilter, accountDataEventFilter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)
// If both positions are the same, it means that the data was saved after the rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High(),
// latest room event. In that case, we need to decrement the old position as
// it would prevent the SQL request from returning anything.
if oldPos == newPos {
oldPos--
}
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos,
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)), pq.StringArray(filterConvertTypeWildcardToSQL(accountDataEventFilter.NotTypes)),
accountDataEventFilter.Limit, accountDataEventFilter.Limit,

View file

@ -117,10 +117,10 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange( func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]gomatrixserverlib.HeaderedEvent, error) { ) (map[string]gomatrixserverlib.HeaderedEvent, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -155,13 +155,13 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, oldPos, newPos, ctx, r.Low(), r.High(),
pq.StringArray(stateFilter.Senders), pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders), pq.StringArray(stateFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
@ -198,8 +198,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
// since it'll just mark the event as not being needed. // since it'll just mark the event as not being needed.
if len(addIDs) < len(delIDs) { if len(addIDs) < len(delIDs) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"since": oldPos, "since": r.From,
"current": newPos, "current": r.To,
"adds": addIDs, "adds": addIDs,
"dels": delIDs, "dels": delIDs,
}).Warn("StateBetween: ignoring deleted state") }).Warn("StateBetween: ignoring deleted state")
@ -298,7 +298,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync. // from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, r types.Range, limit int,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -307,7 +307,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else { } else {
stmt = common.TxStmt(txn, s.selectRecentEventsStmt) stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
} }
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -331,10 +331,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
// from a given position, up to a maximum of 'limit'. // from a given position, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, r types.Range, limit int,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -54,17 +54,22 @@ func (d *Database) GetEventsInStreamingRange(
roomID string, limit int, roomID string, limit int,
backwardOrdering bool, backwardOrdering bool,
) (events []types.StreamEvent, err error) { ) (events []types.StreamEvent, err error) {
r := types.Range{
From: from.PDUPosition(),
To: to.PDUPosition(),
Backwards: backwardOrdering,
}
if backwardOrdering { if backwardOrdering {
// When using backward ordering, we want the most recent events first. // When using backward ordering, we want the most recent events first.
if events, err = d.OutputEvents.SelectRecentEvents( if events, err = d.OutputEvents.SelectRecentEvents(
ctx, nil, roomID, to.PDUPosition(), from.PDUPosition(), limit, false, false, ctx, nil, roomID, r, limit, false, false,
); err != nil { ); err != nil {
return return
} }
} else { } else {
// When using forward ordering, we want the least recent events first. // When using forward ordering, we want the least recent events first.
if events, err = d.OutputEvents.SelectEarlyEvents( if events, err = d.OutputEvents.SelectEarlyEvents(
ctx, nil, roomID, from.PDUPosition(), to.PDUPosition(), limit, ctx, nil, roomID, r, limit,
); err != nil { ); err != nil {
return return
} }
@ -167,10 +172,10 @@ func (d *Database) RetireInviteEvent(
// If no data is retrieved, returns an empty map // If no data is retrieved, returns an empty map
// If there was an issue with the retrieval, returns an error // If there was an issue with the retrieval, returns an error
func (d *Database) GetAccountDataInRange( func (d *Database) GetAccountDataInRange(
ctx context.Context, userID string, oldPos, newPos types.StreamPosition, ctx context.Context, userID string, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (map[string][]string, error) { ) (map[string][]string, error) {
return d.AccountData.SelectAccountDataInRange(ctx, userID, oldPos, newPos, accountDataFilterPart) return d.AccountData.SelectAccountDataInRange(ctx, userID, r, accountDataFilterPart)
} }
// UpsertAccountData keeps track of new or updated account data, by saving the type // UpsertAccountData keeps track of new or updated account data, by saving the type
@ -417,7 +422,7 @@ func (d *Database) syncPositionTx(
func (d *Database) addPDUDeltaToResponse( func (d *Database) addPDUDeltaToResponse(
ctx context.Context, ctx context.Context,
device authtypes.Device, device authtypes.Device,
fromPos, toPos types.StreamPosition, r types.Range,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
wantFullState bool, wantFullState bool,
res *types.Response, res *types.Response,
@ -443,11 +448,11 @@ func (d *Database) addPDUDeltaToResponse(
var deltas []stateDelta var deltas []stateDelta
if !wantFullState { if !wantFullState {
deltas, joinedRoomIDs, err = d.getStateDeltas( deltas, joinedRoomIDs, err = d.getStateDeltas(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, ctx, &device, txn, r, device.UserID, &stateFilter,
) )
} else { } else {
deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync( deltas, joinedRoomIDs, err = d.getStateDeltasForFullStateSync(
ctx, &device, txn, fromPos, toPos, device.UserID, &stateFilter, ctx, &device, txn, r, device.UserID, &stateFilter,
) )
} }
if err != nil { if err != nil {
@ -455,14 +460,14 @@ func (d *Database) addPDUDeltaToResponse(
} }
for _, delta := range deltas { for _, delta := range deltas {
err = d.addRoomDeltaToResponse(ctx, &device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) err = d.addRoomDeltaToResponse(ctx, &device, txn, r, delta, numRecentEventsPerRoom, res)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
// TODO: This should be done in getStateDeltas // TODO: This should be done in getStateDeltas
if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil { if err = d.addInvitesToResponse(ctx, txn, device.UserID, r, res); err != nil {
return nil, err return nil, err
} }
@ -534,8 +539,12 @@ func (d *Database) IncrementalSync(
var joinedRoomIDs []string var joinedRoomIDs []string
var err error var err error
if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState { if fromPos.PDUPosition() != toPos.PDUPosition() || wantFullState {
r := types.Range{
From: fromPos.PDUPosition(),
To: toPos.PDUPosition(),
}
joinedRoomIDs, err = d.addPDUDeltaToResponse( joinedRoomIDs, err = d.addPDUDeltaToResponse(
ctx, device, fromPos.PDUPosition(), toPos.PDUPosition(), numRecentEventsPerRoom, wantFullState, res, ctx, device, r, numRecentEventsPerRoom, wantFullState, res,
) )
} else { } else {
joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership( joinedRoomIDs, err = d.CurrentRoomState.SelectRoomIDsWithMembership(
@ -589,6 +598,10 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
if err != nil { if err != nil {
return return
} }
r := types.Range{
From: 0,
To: toPos.PDUPosition(),
}
res = types.NewResponse(toPos) res = types.NewResponse(toPos)
@ -611,8 +624,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316 // See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
var recentStreamEvents []types.StreamEvent var recentStreamEvents []types.StreamEvent
recentStreamEvents, err = d.OutputEvents.SelectRecentEvents( recentStreamEvents, err = d.OutputEvents.SelectRecentEvents(
ctx, txn, roomID, types.StreamPosition(0), toPos.PDUPosition(), ctx, txn, roomID, r, numRecentEventsPerRoom, true, true,
numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
return return
@ -644,7 +656,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
res.Rooms.Join[roomID] = *jr res.Rooms.Join[roomID] = *jr
} }
if err = d.addInvitesToResponse(ctx, txn, userID, 0, toPos.PDUPosition(), res); err != nil { if err = d.addInvitesToResponse(ctx, txn, userID, r, res); err != nil {
return return
} }
@ -686,11 +698,11 @@ var txReadOnlySnapshot = sql.TxOptions{
func (d *Database) addInvitesToResponse( func (d *Database) addInvitesToResponse(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
userID string, userID string,
fromPos, toPos types.StreamPosition, r types.Range,
res *types.Response, res *types.Response,
) error { ) error {
invites, err := d.Invites.SelectInviteEventsInRange( invites, err := d.Invites.SelectInviteEventsInRange(
ctx, txn, userID, fromPos, toPos, ctx, txn, userID, r,
) )
if err != nil { if err != nil {
return err return err
@ -726,12 +738,11 @@ func (d *Database) addRoomDeltaToResponse(
ctx context.Context, ctx context.Context,
device *authtypes.Device, device *authtypes.Device,
txn *sql.Tx, txn *sql.Tx,
fromPos, toPos types.StreamPosition, r types.Range,
delta stateDelta, delta stateDelta,
numRecentEventsPerRoom int, numRecentEventsPerRoom int,
res *types.Response, res *types.Response,
) error { ) error {
endPos := toPos
if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave { if delta.membershipPos > 0 && delta.membership == gomatrixserverlib.Leave {
// make sure we don't leak recent events after the leave event. // make sure we don't leak recent events after the leave event.
// TODO: History visibility makes this somewhat complex to handle correctly. For example: // TODO: History visibility makes this somewhat complex to handle correctly. For example:
@ -739,10 +750,10 @@ func (d *Database) addRoomDeltaToResponse(
// TODO: This will fail on join -> leave -> sensitive msg -> join -> leave // TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
// in a single /sync request // in a single /sync request
// This is all "okay" assuming history_visibility == "shared" which it is by default. // This is all "okay" assuming history_visibility == "shared" which it is by default.
endPos = delta.membershipPos r.To = delta.membershipPos
} }
recentStreamEvents, err := d.OutputEvents.SelectRecentEvents( recentStreamEvents, err := d.OutputEvents.SelectRecentEvents(
ctx, txn, delta.roomID, types.StreamPosition(fromPos), types.StreamPosition(endPos), ctx, txn, delta.roomID, r,
numRecentEventsPerRoom, true, true, numRecentEventsPerRoom, true, true,
) )
if err != nil { if err != nil {
@ -872,7 +883,7 @@ func (d *Database) fetchMissingStateEvents(
// A list of joined room IDs is also returned in case the caller needs it. // A list of joined room IDs is also returned in case the caller needs it.
func (d *Database) getStateDeltas( func (d *Database) getStateDeltas(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
// Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821
@ -886,7 +897,7 @@ func (d *Database) getStateDeltas(
var deltas []stateDelta var deltas []stateDelta
// get all the state events ever between these two positions // get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, fromPos, toPos, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
@ -947,7 +958,7 @@ func (d *Database) getStateDeltas(
// updates for other rooms. // updates for other rooms.
func (d *Database) getStateDeltasForFullStateSync( func (d *Database) getStateDeltasForFullStateSync(
ctx context.Context, device *authtypes.Device, txn *sql.Tx, ctx context.Context, device *authtypes.Device, txn *sql.Tx,
fromPos, toPos types.StreamPosition, userID string, r types.Range, userID string,
stateFilter *gomatrixserverlib.StateFilter, stateFilter *gomatrixserverlib.StateFilter,
) ([]stateDelta, []string, error) { ) ([]stateDelta, []string, error) {
joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join) joinedRoomIDs, err := d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, txn, userID, gomatrixserverlib.Join)
@ -972,7 +983,7 @@ func (d *Database) getStateDeltasForFullStateSync(
} }
// Get all the state events ever between these two positions // Get all the state events ever between these two positions
stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, fromPos, toPos, stateFilter) stateNeeded, eventMap, err := d.OutputEvents.SelectStateInRange(ctx, txn, r, stateFilter)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }

View file

@ -91,19 +91,12 @@ func (s *accountDataStatements) InsertAccountData(
func (s *accountDataStatements) SelectAccountDataInRange( func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context, ctx context.Context,
userID string, userID string,
oldPos, newPos types.StreamPosition, r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter, accountDataFilterPart *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) { ) (data map[string][]string, err error) {
data = make(map[string][]string) data = make(map[string][]string)
// If both positions are the same, it means that the data was saved after the rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
// latest room event. In that case, we need to decrement the old position as
// it would prevent the SQL request from returning anything.
if oldPos == newPos {
oldPos--
}
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, oldPos, newPos)
if err != nil { if err != nil {
return return
} }

View file

@ -121,10 +121,10 @@ func (s *inviteEventsStatements) DeleteInviteEvent(
// selectInviteEventsInRange returns a map of room ID to invite event for the // selectInviteEventsInRange returns a map of room ID to invite event for the
// active invites for the target user ID in the supplied range. // active invites for the target user ID in the supplied range.
func (s *inviteEventsStatements) SelectInviteEventsInRange( func (s *inviteEventsStatements) SelectInviteEventsInRange(
ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition, ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range,
) (map[string]gomatrixserverlib.HeaderedEvent, error) { ) (map[string]gomatrixserverlib.HeaderedEvent, error) {
stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt) stmt := common.TxStmt(txn, s.selectInviteEventsInRangeStmt)
rows, err := stmt.QueryContext(ctx, targetUserID, startPos, endPos) rows, err := stmt.QueryContext(ctx, targetUserID, r.Low(), r.High())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -146,13 +146,13 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
// Results are bucketed based on the room ID. If the same state is overwritten multiple times between the // Results are bucketed based on the room ID. If the same state is overwritten multiple times between the
// two positions, only the most recent state is returned. // two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange( func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, ctx context.Context, txn *sql.Tx, r types.Range,
stateFilterPart *gomatrixserverlib.StateFilter, stateFilterPart *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) { ) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectStateInRangeStmt) stmt := common.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext( rows, err := stmt.QueryContext(
ctx, oldPos, newPos, ctx, r.Low(), r.High(),
/*pq.StringArray(stateFilterPart.Senders), /*pq.StringArray(stateFilterPart.Senders),
pq.StringArray(stateFilterPart.NotSenders), pq.StringArray(stateFilterPart.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)), pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
@ -195,8 +195,8 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
// since it'll just mark the event as not being needed. // since it'll just mark the event as not being needed.
if len(addIDs) < len(delIDs) { if len(addIDs) < len(delIDs) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"since": oldPos, "since": r.From,
"current": newPos, "current": r.To,
"adds": addIDsJSON, "adds": addIDsJSON,
"dels": delIDsJSON, "dels": delIDsJSON,
}).Warn("StateBetween: ignoring deleted state") }).Warn("StateBetween: ignoring deleted state")
@ -308,7 +308,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents( func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, r types.Range, limit int,
chronologicalOrder bool, onlySyncEvents bool, chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
var stmt *sql.Stmt var stmt *sql.Stmt
@ -318,7 +318,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
stmt = common.TxStmt(txn, s.selectRecentEventsStmt) stmt = common.TxStmt(txn, s.selectRecentEventsStmt)
} }
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -340,10 +340,10 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
func (s *outputRoomEventsStatements) SelectEarlyEvents( func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
roomID string, fromPos, toPos types.StreamPosition, limit int, roomID string, r types.Range, limit int,
) ([]types.StreamEvent, error) { ) ([]types.StreamEvent, error) {
stmt := common.TxStmt(txn, s.selectEarlyEventsStmt) stmt := common.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, fromPos, toPos, limit) rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -11,29 +11,29 @@ import (
type AccountData interface { type AccountData interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error) InsertAccountData(ctx context.Context, txn *sql.Tx, userID, roomID, dataType string) (pos types.StreamPosition, err error)
// SelectAccountDataInRange returns a map of room ID to a list of `dataType`. The range is exclusive of `lowPos` and inclusive of `hiPos`. // SelectAccountDataInRange returns a map of room ID to a list of `dataType`.
SelectAccountDataInRange(ctx context.Context, userID string, lowPos, hiPos types.StreamPosition, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error) SelectAccountDataInRange(ctx context.Context, userID string, r types.Range, accountDataEventFilter *gomatrixserverlib.EventFilter) (data map[string][]string, err error)
SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxAccountDataID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
type Invites interface { type Invites interface {
InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error) InsertInviteEvent(ctx context.Context, txn *sql.Tx, inviteEvent gomatrixserverlib.HeaderedEvent) (streamPos types.StreamPosition, err error)
DeleteInviteEvent(ctx context.Context, inviteEventID string) error DeleteInviteEvent(ctx context.Context, inviteEventID string) error
// SelectInviteEventsInRange returns a map of room ID to invite events. The range is exclusive of `startPos` and inclusive of `endPos`. // SelectInviteEventsInRange returns a map of room ID to invite events.
SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, startPos, endPos types.StreamPosition) (map[string]gomatrixserverlib.HeaderedEvent, error) SelectInviteEventsInRange(ctx context.Context, txn *sql.Tx, targetUserID string, r types.Range) (map[string]gomatrixserverlib.HeaderedEvent, error)
SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxInviteID(ctx context.Context, txn *sql.Tx) (id int64, err error)
} }
type Events interface { type Events interface {
SelectStateInRange(ctx context.Context, txn *sql.Tx, oldPos, newPos types.StreamPosition, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error) SelectStateInRange(ctx context.Context, txn *sql.Tx, r types.Range, stateFilter *gomatrixserverlib.StateFilter) (map[string]map[string]bool, map[string]types.StreamEvent, error)
SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error) SelectMaxEventID(ctx context.Context, txn *sql.Tx) (id int64, err error)
InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error) InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error)
// SelectRecentEvents returns events between the two stream positions: exclusive of `fromPos` and inclusive of `toPos`. // SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync. // If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. // Returns up to `limit` events.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, error) SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, error)
// SelectEarlyEvents returns the earliest events in the given room, exclusive of `fromPos` and inclusive of `toPos`. // SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, fromPos, toPos types.StreamPosition, limit int) ([]types.StreamEvent, error) SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error) SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
} }

View file

@ -184,11 +184,20 @@ func (rp *RequestPool) appendAccountData(
return data, nil return data, nil
} }
r := types.Range{
From: req.since.PDUPosition(),
To: currentPos,
}
// If both positions are the same, it means that the data was saved after the
// latest room event. In that case, we need to decrement the old position as
// results are exclusive of Low.
if r.Low() == r.High() {
r.From--
}
// Sync is not initial, get all account data since the latest sync // Sync is not initial, get all account data since the latest sync
dataTypes, err := rp.db.GetAccountDataInRange( dataTypes, err := rp.db.GetAccountDataInRange(
req.ctx, userID, req.ctx, userID, r, accountDataFilter,
types.StreamPosition(req.since.PDUPosition()), types.StreamPosition(currentPos),
accountDataFilter,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -47,6 +47,34 @@ type StreamEvent struct {
ExcludeFromSync bool ExcludeFromSync bool
} }
// Range represents a range between two stream positions.
type Range struct {
// From is the position the client has already received.
From StreamPosition
// To is the position the client is going towards.
To StreamPosition
// True if the client is going backwards
Backwards bool
}
// Low returns the low number of the range.
// This represents the position the client already has and hence is exclusive.
func (r *Range) Low() StreamPosition {
if !r.Backwards {
return r.From
}
return r.To
}
// High returns the high number of the range
// This represents the position the client is going towards and hence is inclusive.
func (r *Range) High() StreamPosition {
if !r.Backwards {
return r.To
}
return r.From
}
// SyncTokenType represents the type of a sync token. // SyncTokenType represents the type of a sync token.
// It can be either "s" (representing a position in the whole stream of events) // It can be either "s" (representing a position in the whole stream of events)
// or "t" (representing a position in a room's topology/depth). // or "t" (representing a position in a room's topology/depth).