Basic sync filtering (#1721)

* Add some filtering (postgres only for now)

* Fix build error

* Try to use request filter

* Use default filter as a template when retrieving from the database

* Remove unused strut

* Update sytest-whitelist

* Add filtering to SelectEarlyEvents

* Fix Postgres selectEarlyEvents query

* Attempt filtering on SQLite

* Test limit, set field for limit/order in prepareWithFilters

* Remove debug logging, add comments

* Tweaks, debug logging

* Separate SQLite stream IDs

* Fix filtering in current state table

* Fix lock issues

* More tweaks

* Current state requires room ID

* Review comments
This commit is contained in:
Neil Alexander 2021-01-19 18:00:42 +00:00 committed by GitHub
parent 80aa9aa8b0
commit b70238f2d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 279 additions and 160 deletions

View file

@ -40,7 +40,7 @@ type Database interface {
GetStateDeltas(ctx context.Context, device *userapi.Device, r types.Range, userID string, stateFilter *gomatrixserverlib.StateFilter) ([]types.StateDelta, []string, error)
RoomIDsWithMembership(ctx context.Context, userID string, membership string) ([]string, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
GetBackwardTopologyPos(ctx context.Context, events []types.StreamEvent) (types.TopologyToken, error)
PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error)
@ -105,7 +105,7 @@ type Database interface {
// Returns an error if there was a problem communicating with the database.
DeletePeeks(ctx context.Context, RoomID, UserID string) (types.StreamPosition, error)
// GetEventsInStreamingRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
GetEventsInStreamingRange(ctx context.Context, from, to *types.StreamingToken, roomID string, eventFilter *gomatrixserverlib.RoomEventFilter, backwardOrdering bool) (events []types.StreamEvent, err error)
// GetEventsInTopologicalRange retrieves all of the events on a given ordering using the given extremities and limit.
GetEventsInTopologicalRange(ctx context.Context, from, to *types.TopologyToken, roomID string, limit int, backwardOrdering bool) (events []types.StreamEvent, err error)
// EventPositionInTopology returns the depth and stream position of the given event.

View file

@ -83,7 +83,7 @@ func (s *filterStatements) SelectFilter(
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
filter := gomatrixserverlib.DefaultFilter()
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}

View file

@ -84,17 +84,29 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4"
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4"
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id DESC LIMIT $8"
const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4"
" AND ( $4::text[] IS NULL OR sender = ANY($4) )" +
" AND ( $5::text[] IS NULL OR NOT(sender = ANY($5)) )" +
" AND ( $6::text[] IS NULL OR type LIKE ANY($6) )" +
" AND ( $7::text[] IS NULL OR NOT(type LIKE ANY($7)) )" +
" ORDER BY id ASC LIMIT $8"
const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events"
@ -322,7 +334,7 @@ func (s *outputRoomEventsStatements) InsertEvent(
// from sync.
func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt
@ -331,7 +343,14 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,
)
if err != nil {
return nil, false, err
}
@ -350,7 +369,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}
// we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > limit {
if len(events) > eventFilter.Limit {
limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
@ -367,10 +386,17 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
// from a given position, up to a maximum of 'limit'.
func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,
)
if err != nil {
return nil, err
}

View file

@ -110,8 +110,8 @@ func (d *Database) RoomIDsWithMembership(ctx context.Context, userID string, mem
return d.CurrentRoomState.SelectRoomIDsWithMembership(ctx, nil, userID, membership)
}
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, limit, chronologicalOrder, onlySyncEvents)
func (d *Database) RecentEvents(ctx context.Context, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error) {
return d.OutputEvents.SelectRecentEvents(ctx, nil, roomID, r, eventFilter, chronologicalOrder, onlySyncEvents)
}
func (d *Database) PositionInTopology(ctx context.Context, eventID string) (pos types.StreamPosition, spos types.StreamPosition, err error) {
@ -151,7 +151,7 @@ func (d *Database) Events(ctx context.Context, eventIDs []string) ([]*gomatrixse
func (d *Database) GetEventsInStreamingRange(
ctx context.Context,
from, to *types.StreamingToken,
roomID string, limit int,
roomID string, eventFilter *gomatrixserverlib.RoomEventFilter,
backwardOrdering bool,
) (events []types.StreamEvent, err error) {
r := types.Range{
@ -162,14 +162,14 @@ func (d *Database) GetEventsInStreamingRange(
if backwardOrdering {
// When using backward ordering, we want the most recent events first.
if events, _, err = d.OutputEvents.SelectRecentEvents(
ctx, nil, roomID, r, limit, false, false,
ctx, nil, roomID, r, eventFilter, false, false,
); err != nil {
return
}
} else {
// When using forward ordering, we want the least recent events first.
if events, err = d.OutputEvents.SelectEarlyEvents(
ctx, nil, roomID, r, limit,
ctx, nil, roomID, r, eventFilter,
); err != nil {
return
}

View file

@ -82,7 +82,7 @@ func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx,
userID, roomID, dataType string,
) (pos types.StreamPosition, err error) {
pos, err = s.streamIDStatements.nextStreamID(ctx, txn)
pos, err = s.streamIDStatements.nextAccountDataID(ctx, txn)
if err != nil {
return
}

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
@ -66,13 +67,8 @@ const selectRoomIDsWithMembershipSQL = "" +
"SELECT DISTINCT room_id FROM syncapi_current_room_state WHERE type = 'm.room.member' AND state_key = $1 AND membership = $2"
const selectCurrentStateSQL = "" +
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1" +
" AND ( $2 IS NULL OR sender IN ($2) )" +
" AND ( $3 IS NULL OR NOT(sender IN ($3)) )" +
" AND ( $4 IS NULL OR type IN ($4) )" +
" AND ( $5 IS NULL OR NOT(type IN ($5)) )" +
" AND ( $6 IS NULL OR contains_url = $6 )" +
" LIMIT $7"
"SELECT event_id, headered_event_json FROM syncapi_current_room_state WHERE room_id = $1"
// WHEN, ORDER BY and LIMIT will be added by prepareWithFilter
const selectJoinedUsersSQL = "" +
"SELECT room_id, state_key FROM syncapi_current_room_state WHERE type = 'm.room.member' AND membership = 'join'"
@ -95,7 +91,6 @@ type currentRoomStateStatements struct {
deleteRoomStateByEventIDStmt *sql.Stmt
DeleteRoomStateForRoomStmt *sql.Stmt
selectRoomIDsWithMembershipStmt *sql.Stmt
selectCurrentStateStmt *sql.Stmt
selectJoinedUsersStmt *sql.Stmt
selectStateEventStmt *sql.Stmt
}
@ -121,9 +116,6 @@ func NewSqliteCurrentRoomStateTable(db *sql.DB, streamID *streamIDStatements) (t
if s.selectRoomIDsWithMembershipStmt, err = db.Prepare(selectRoomIDsWithMembershipSQL); err != nil {
return nil, err
}
if s.selectCurrentStateStmt, err = db.Prepare(selectCurrentStateSQL); err != nil {
return nil, err
}
if s.selectJoinedUsersStmt, err = db.Prepare(selectJoinedUsersSQL); err != nil {
return nil, err
}
@ -185,17 +177,22 @@ func (s *currentRoomStateStatements) SelectRoomIDsWithMembership(
// CurrentState returns all the current state events for the given room.
func (s *currentRoomStateStatements) SelectCurrentState(
ctx context.Context, txn *sql.Tx, roomID string,
stateFilterPart *gomatrixserverlib.StateFilter,
stateFilter *gomatrixserverlib.StateFilter,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
rows, err := stmt.QueryContext(ctx, roomID,
nil, // FIXME: pq.StringArray(stateFilterPart.Senders),
nil, // FIXME: pq.StringArray(stateFilterPart.NotSenders),
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
nil, // FIXME: pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
stateFilterPart.ContainsURL,
stateFilterPart.Limit,
stmt, params, err := prepareWithFilters(
s.db, txn, selectCurrentStateSQL,
[]interface{}{
roomID,
},
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, FilterOrderNone,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}

View file

@ -87,7 +87,7 @@ func (s *filterStatements) SelectFilter(
}
// Unmarshal JSON into Filter struct
var filter gomatrixserverlib.Filter
filter := gomatrixserverlib.DefaultFilter()
if err = json.Unmarshal(filterData, &filter); err != nil {
return nil, err
}

View file

@ -0,0 +1,76 @@
package sqlite3
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
type FilterOrder int
const (
FilterOrderNone = iota
FilterOrderAsc
FilterOrderDesc
)
// prepareWithFilters returns a prepared statement with the
// relevant filters included. It also includes an []interface{}
// list of all the relevant parameters to pass straight to
// QueryContext, QueryRowContext etc.
// We don't take the filter object directly here because the
// fields might come from either a StateFilter or an EventFilter,
// and it's easier just to have the caller extract the relevant
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
}
}
switch order {
case FilterOrderAsc:
query += " ORDER BY id ASC"
case FilterOrderDesc:
query += " ORDER BY id DESC"
}
query += fmt.Sprintf(" LIMIT $%d", offset+1)
params = append(params, limit)
var stmt *sql.Stmt
var err error
if txn != nil {
stmt, err = txn.Prepare(query)
} else {
stmt, err = db.Prepare(query)
}
if err != nil {
return nil, nil, fmt.Errorf("s.db.Prepare: %w", err)
}
return stmt, params, nil
}

View file

@ -93,7 +93,7 @@ func NewSqliteInvitesTable(db *sql.DB, streamID *streamIDStatements) (tables.Inv
func (s *inviteEventsStatements) InsertInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEvent *gomatrixserverlib.HeaderedEvent,
) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err = s.streamIDStatements.nextInviteID(ctx, txn)
if err != nil {
return
}
@ -119,7 +119,7 @@ func (s *inviteEventsStatements) InsertInviteEvent(
func (s *inviteEventsStatements) DeleteInviteEvent(
ctx context.Context, txn *sql.Tx, inviteEventID string,
) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err := s.streamIDStatements.nextInviteID(ctx, txn)
if err != nil {
return streamPos, err
}

View file

@ -19,6 +19,7 @@ import (
"context"
"database/sql"
"encoding/json"
"fmt"
"sort"
"github.com/matrix-org/dendrite/internal"
@ -60,18 +61,18 @@ const selectEventsSQL = "" +
const selectRecentEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id DESC LIMIT $4"
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectRecentEventsForSyncSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE" +
" ORDER BY id DESC LIMIT $4"
" WHERE room_id = $1 AND id > $2 AND id <= $3 AND exclude_from_sync = FALSE"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectEarlyEventsSQL = "" +
"SELECT event_id, id, headered_event_json, session_id, exclude_from_sync, transaction_id FROM syncapi_output_room_events" +
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC LIMIT $4"
" WHERE room_id = $1 AND id > $2 AND id <= $3"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const selectMaxEventIDSQL = "" +
"SELECT MAX(id) FROM syncapi_output_room_events"
@ -79,45 +80,24 @@ const selectMaxEventIDSQL = "" +
const updateEventJSONSQL = "" +
"UPDATE syncapi_output_room_events SET headered_event_json=$1 WHERE event_id=$2"
// In order for us to apply the state updates correctly, rows need to be ordered in the order they were received (id).
/*
$1 = oldPos,
$2 = newPos,
$3 = pq.StringArray(stateFilterPart.Senders),
$4 = pq.StringArray(stateFilterPart.NotSenders),
$5 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
$6 = pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
$7 = stateFilterPart.ContainsURL,
$8 = stateFilterPart.Limit,
*/
const selectStateInRangeSQL = "" +
"SELECT id, headered_event_json, exclude_from_sync, add_state_ids, remove_state_ids" +
" FROM syncapi_output_room_events" +
" WHERE (id > $1 AND id <= $2)" + // old/new pos
" AND (add_state_ids IS NOT NULL OR remove_state_ids IS NOT NULL)" +
/* " AND ( $3 IS NULL OR sender IN ($3) )" + // sender
" AND ( $4 IS NULL OR NOT(sender IN ($4)) )" + // not sender
" AND ( $5 IS NULL OR type IN ($5) )" + // type
" AND ( $6 IS NULL OR NOT(type IN ($6)) )" + // not type
" AND ( $7 IS NULL OR contains_url = $7)" + // contains URL? */
" ORDER BY id ASC" +
" LIMIT $8" // limit
" WHERE (id > $1 AND id <= $2)" +
" AND ((add_state_ids IS NOT NULL AND add_state_ids != '') OR (remove_state_ids IS NOT NULL AND remove_state_ids != ''))"
// WHEN, ORDER BY and LIMIT are appended by prepareWithFilters
const deleteEventsForRoomSQL = "" +
"DELETE FROM syncapi_output_room_events WHERE room_id = $1"
type outputRoomEventsStatements struct {
db *sql.DB
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
selectRecentEventsStmt *sql.Stmt
selectRecentEventsForSyncStmt *sql.Stmt
selectEarlyEventsStmt *sql.Stmt
selectStateInRangeStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
db *sql.DB
streamIDStatements *streamIDStatements
insertEventStmt *sql.Stmt
selectEventsStmt *sql.Stmt
selectMaxEventIDStmt *sql.Stmt
updateEventJSONStmt *sql.Stmt
deleteEventsForRoomStmt *sql.Stmt
}
func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Events, error) {
@ -138,18 +118,6 @@ func NewSqliteEventsTable(db *sql.DB, streamID *streamIDStatements) (tables.Even
if s.selectMaxEventIDStmt, err = db.Prepare(selectMaxEventIDSQL); err != nil {
return nil, err
}
if s.selectRecentEventsStmt, err = db.Prepare(selectRecentEventsSQL); err != nil {
return nil, err
}
if s.selectRecentEventsForSyncStmt, err = db.Prepare(selectRecentEventsForSyncSQL); err != nil {
return nil, err
}
if s.selectEarlyEventsStmt, err = db.Prepare(selectEarlyEventsSQL); err != nil {
return nil, err
}
if s.selectStateInRangeStmt, err = db.Prepare(selectStateInRangeSQL); err != nil {
return nil, err
}
if s.updateEventJSONStmt, err = db.Prepare(updateEventJSONSQL); err != nil {
return nil, err
}
@ -173,19 +141,22 @@ func (s *outputRoomEventsStatements) UpdateEventJSON(ctx context.Context, event
// two positions, only the most recent state is returned.
func (s *outputRoomEventsStatements) SelectStateInRange(
ctx context.Context, txn *sql.Tx, r types.Range,
stateFilterPart *gomatrixserverlib.StateFilter,
stateFilter *gomatrixserverlib.StateFilter,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(),
/*pq.StringArray(stateFilterPart.Senders),
pq.StringArray(stateFilterPart.NotSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilterPart.NotTypes)),
stateFilterPart.ContainsURL,*/
stateFilterPart.Limit,
stmt, params, err := prepareWithFilters(
s.db, txn, selectStateInRangeSQL,
[]interface{}{
r.Low(), r.High(),
},
stateFilter.Senders, stateFilter.NotSenders,
stateFilter.Types, stateFilter.NotTypes,
stateFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, nil, err
}
@ -298,16 +269,21 @@ func (s *outputRoomEventsStatements) InsertEvent(
return 0, err
}
addStateJSON, err := json.Marshal(addState)
if err != nil {
return 0, err
var addStateJSON, removeStateJSON []byte
if len(addState) > 0 {
addStateJSON, err = json.Marshal(addState)
}
removeStateJSON, err := json.Marshal(removeState)
if err != nil {
return 0, err
return 0, fmt.Errorf("json.Marshal(addState): %w", err)
}
if len(removeState) > 0 {
removeStateJSON, err = json.Marshal(removeState)
}
if err != nil {
return 0, fmt.Errorf("json.Marshal(removeState): %w", err)
}
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return 0, err
}
@ -333,17 +309,30 @@ func (s *outputRoomEventsStatements) InsertEvent(
func (s *outputRoomEventsStatements) SelectRecentEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
chronologicalOrder bool, onlySyncEvents bool,
) ([]types.StreamEvent, bool, error) {
var stmt *sql.Stmt
var query string
if onlySyncEvents {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
query = selectRecentEventsForSyncSQL
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
query = selectRecentEventsSQL
}
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
stmt, params, err := prepareWithFilters(
s.db, txn, query,
[]interface{}{
roomID, r.Low(), r.High(),
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit+1, FilterOrderDesc,
)
if err != nil {
return nil, false, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, false, err
}
@ -362,7 +351,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
}
// we queried for 1 more than the limit, so if we returned one more mark limited=true
limited := false
if len(events) > limit {
if len(events) > eventFilter.Limit {
limited = true
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
if chronologicalOrder {
@ -376,10 +365,21 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, limit int,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
stmt, params, err := prepareWithFilters(
s.db, txn, selectEarlyEventsSQL,
[]interface{}{
roomID, r.Low(), r.High(),
},
eventFilter.Senders, eventFilter.NotSenders,
eventFilter.Types, eventFilter.NotTypes,
eventFilter.Limit, FilterOrderAsc,
)
if err != nil {
return nil, fmt.Errorf("s.prepareWithFilters: %w", err)
}
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return nil, err
}

View file

@ -108,7 +108,7 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks
func (s *peekStatements) InsertPeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return
}
@ -120,7 +120,7 @@ func (s *peekStatements) InsertPeek(
func (s *peekStatements) DeletePeek(
ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string,
) (streamPos types.StreamPosition, err error) {
streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err = s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return
}
@ -131,7 +131,7 @@ func (s *peekStatements) DeletePeek(
func (s *peekStatements) DeletePeeks(
ctx context.Context, txn *sql.Tx, roomID, userID string,
) (types.StreamPosition, error) {
streamPos, err := s.streamIDStatements.nextStreamID(ctx, txn)
streamPos, err := s.streamIDStatements.nextPDUID(ctx, txn)
if err != nil {
return 0, err
}

View file

@ -20,6 +20,10 @@ INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("global", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("receipt", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("accountdata", 0)
ON CONFLICT DO NOTHING;
INSERT INTO syncapi_stream_id (stream_name, stream_id) VALUES ("invite", 0)
ON CONFLICT DO NOTHING;
`
const increaseStreamIDStmt = "" +
@ -49,7 +53,7 @@ func (s *streamIDStatements) prepare(db *sql.DB) (err error) {
return
}
func (s *streamIDStatements) nextStreamID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
func (s *streamIDStatements) nextPDUID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "global"); err != nil {
@ -68,3 +72,23 @@ func (s *streamIDStatements) nextReceiptID(ctx context.Context, txn *sql.Tx) (po
err = selectStmt.QueryRowContext(ctx, "receipt").Scan(&pos)
return
}
func (s *streamIDStatements) nextInviteID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "invite"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "invite").Scan(&pos)
return
}
func (s *streamIDStatements) nextAccountDataID(ctx context.Context, txn *sql.Tx) (pos types.StreamPosition, err error) {
increaseStmt := sqlutil.TxStmt(txn, s.increaseStreamIDStmt)
selectStmt := sqlutil.TxStmt(txn, s.selectStreamIDStmt)
if _, err = increaseStmt.ExecContext(ctx, "accountdata"); err != nil {
return
}
err = selectStmt.QueryRowContext(ctx, "accountdata").Scan(&pos)
return
}

View file

@ -56,9 +56,9 @@ type Events interface {
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
// SelectEarlyEvents returns the earliest events in the given room.
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter) ([]types.StreamEvent, error)
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
UpdateEventJSON(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error
// DeleteEventsForRoom removes all event information for a room. This should only be done when removing the room entirely.