mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 05:42:46 +00:00
Implement /sync limited
and read timeline limit from stored filters (#1168)
* Move filter table to syncapi where it is used * Implement /sync `limited` and read timeline limit from stored filters We now fully handle `room.timeline.limit` filters (in-line + stored) and return the right value for `limited` syncs. * Update whitelist * Default to the default timeline limit if it's unset, also strip the extra event correctly * Update whitelist
This commit is contained in:
parent
164057a3be
commit
1ad7219e4b
19 changed files with 194 additions and 135 deletions
128
syncapi/routing/filter.go
Normal file
128
syncapi/routing/filter.go
Normal file
|
@ -0,0 +1,128 @@
|
|||
// Copyright 2017 Jan Christian Grünhage
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/sync"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// GetFilter implements GET /_matrix/client/r0/user/{userId}/filter/{filterId}
|
||||
func GetFilter(
|
||||
req *http.Request, device *api.Device, syncDB storage.Database, userID string, filterID string,
|
||||
) util.JSONResponse {
|
||||
if userID != device.UserID {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Cannot get filters for other users"),
|
||||
}
|
||||
}
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
filter, err := syncDB.GetFilter(req.Context(), localpart, filterID)
|
||||
if err != nil {
|
||||
//TODO better error handling. This error message is *probably* right,
|
||||
// but if there are obscure db errors, this will also be returned,
|
||||
// even though it is not correct.
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.NotFound("No such filter"),
|
||||
}
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: filter,
|
||||
}
|
||||
}
|
||||
|
||||
type filterResponse struct {
|
||||
FilterID string `json:"filter_id"`
|
||||
}
|
||||
|
||||
//PutFilter implements POST /_matrix/client/r0/user/{userId}/filter
|
||||
func PutFilter(
|
||||
req *http.Request, device *api.Device, syncDB storage.Database, userID string,
|
||||
) util.JSONResponse {
|
||||
if userID != device.UserID {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusForbidden,
|
||||
JSON: jsonerror.Forbidden("Cannot create filters for other users"),
|
||||
}
|
||||
}
|
||||
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
var filter gomatrixserverlib.Filter
|
||||
|
||||
defer req.Body.Close() // nolint:errcheck
|
||||
body, err := ioutil.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("The request body could not be read. " + err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(body, &filter); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()),
|
||||
}
|
||||
}
|
||||
// the filter `limit` is `int` which defaults to 0 if not set which is not what we want. We want to use the default
|
||||
// limit if it is unset, which is what this does.
|
||||
limitRes := gjson.GetBytes(body, "room.timeline.limit")
|
||||
if !limitRes.Exists() {
|
||||
util.GetLogger(req.Context()).Infof("missing timeline limit, using default")
|
||||
filter.Room.Timeline.Limit = sync.DefaultTimelineLimit
|
||||
}
|
||||
|
||||
// Validate generates a user-friendly error
|
||||
if err = filter.Validate(); err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: jsonerror.BadJSON("Invalid filter: " + err.Error()),
|
||||
}
|
||||
}
|
||||
|
||||
filterID, err := syncDB.PutFilter(req.Context(), localpart, &filter)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("syncDB.PutFilter failed")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: filterResponse{FilterID: filterID},
|
||||
}
|
||||
}
|
|
@ -55,4 +55,24 @@ func Setup(
|
|||
}
|
||||
return OnIncomingMessagesRequest(req, syncDB, vars["roomID"], federation, rsAPI, cfg)
|
||||
})).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/filter",
|
||||
httputil.MakeAuthAPI("put_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return PutFilter(req, device, syncDB, vars["userId"])
|
||||
}),
|
||||
).Methods(http.MethodPost, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/user/{userId}/filter/{filterId}",
|
||||
httputil.MakeAuthAPI("get_filter", userAPI, func(req *http.Request, device *userapi.Device) util.JSONResponse {
|
||||
vars, err := httputil.URLDecodeMapValues(mux.Vars(req))
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return GetFilter(req, device, syncDB, vars["userId"], vars["filterId"])
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
}
|
||||
|
|
|
@ -128,4 +128,12 @@ type Database interface {
|
|||
CleanSendToDeviceUpdates(ctx context.Context, toUpdate, toDelete []types.SendToDeviceNID, token types.StreamingToken) (err error)
|
||||
// SendToDeviceUpdatesWaiting returns true if there are send-to-device updates waiting to be sent.
|
||||
SendToDeviceUpdatesWaiting(ctx context.Context, userID, deviceID string) (bool, error)
|
||||
// GetFilter looks up the filter associated with a given local user and filter ID.
|
||||
// Returns a filter structure. Otherwise returns an error if no such filter exists
|
||||
// or if there was an error talking to the database.
|
||||
GetFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
|
||||
// PutFilter puts the passed filter into the database.
|
||||
// Returns the filterID as a string. Otherwise returns an error if something
|
||||
// goes wrong.
|
||||
PutFilter(ctx context.Context, localpart string, filter *gomatrixserverlib.Filter) (string, error)
|
||||
}
|
||||
|
|
129
syncapi/storage/postgres/filter_table.go
Normal file
129
syncapi/storage/postgres/filter_table.go
Normal file
|
@ -0,0 +1,129 @@
|
|||
// Copyright 2017 Jan Christian Grünhage
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const filterSchema = `
|
||||
-- Stores data about filters
|
||||
CREATE TABLE IF NOT EXISTS syncapi_filter (
|
||||
-- The filter
|
||||
filter TEXT NOT NULL,
|
||||
-- The ID
|
||||
id SERIAL UNIQUE,
|
||||
-- The localpart of the Matrix user ID associated to this filter
|
||||
localpart TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(id, localpart)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
|
||||
`
|
||||
|
||||
const selectFilterSQL = "" +
|
||||
"SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
|
||||
|
||||
const selectFilterIDByContentSQL = "" +
|
||||
"SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
|
||||
|
||||
const insertFilterSQL = "" +
|
||||
"INSERT INTO syncapi_filter (filter, id, localpart) VALUES ($1, DEFAULT, $2) RETURNING id"
|
||||
|
||||
type filterStatements struct {
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||
_, err := db.Exec(filterSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &filterStatements{}
|
||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) SelectFilter(
|
||||
ctx context.Context, localpart string, filterID string,
|
||||
) (*gomatrixserverlib.Filter, error) {
|
||||
// Retrieve filter from database (stored as canonical JSON)
|
||||
var filterData []byte
|
||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal JSON into Filter struct
|
||||
var filter gomatrixserverlib.Filter
|
||||
if err = json.Unmarshal(filterData, &filter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &filter, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) InsertFilter(
|
||||
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
|
||||
) (filterID string, err error) {
|
||||
var existingFilterID string
|
||||
|
||||
// Serialise json
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Remove whitespaces and sort JSON data
|
||||
// needed to prevent from inserting the same filter multiple times
|
||||
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if filter already exists in the database using its localpart and content
|
||||
//
|
||||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// If it does, return the existing ID
|
||||
if existingFilterID != "" {
|
||||
return existingFilterID, err
|
||||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
err = s.insertFilterStmt.QueryRowContext(ctx, filterJSON, localpart).
|
||||
Scan(&filterID)
|
||||
return
|
||||
}
|
|
@ -301,21 +301,21 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, limit int,
|
||||
chronologicalOrder bool, onlySyncEvents bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
) ([]types.StreamEvent, bool, error) {
|
||||
var stmt *sql.Stmt
|
||||
if onlySyncEvents {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
|
||||
} else {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
|
||||
}
|
||||
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
|
||||
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
|
||||
events, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
if chronologicalOrder {
|
||||
// The events need to be returned from oldest to latest, which isn't
|
||||
|
@ -325,7 +325,19 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
return events[i].StreamPosition < events[j].StreamPosition
|
||||
})
|
||||
}
|
||||
return events, nil
|
||||
// we queried for 1 more than the limit, so if we returned one more mark limited=true
|
||||
limited := false
|
||||
if len(events) > limit {
|
||||
limited = true
|
||||
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
|
||||
if chronologicalOrder {
|
||||
events = events[1:]
|
||||
} else {
|
||||
events = events[:len(events)-1]
|
||||
}
|
||||
}
|
||||
|
||||
return events, limited, nil
|
||||
}
|
||||
|
||||
// selectEarlyEvents returns the earliest events in the given room, starting
|
||||
|
|
|
@ -71,6 +71,10 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filter, err := NewPostgresFilterTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Invites: invites,
|
||||
|
@ -79,6 +83,7 @@ func NewDatabase(dbDataSourceName string, dbProperties sqlutil.DbProperties) (*S
|
|||
Topology: topology,
|
||||
CurrentRoomState: currState,
|
||||
BackwardExtremities: backwardExtremities,
|
||||
Filter: filter,
|
||||
SendToDevice: sendToDevice,
|
||||
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
|
||||
EDUCache: cache.New(),
|
||||
|
|
|
@ -43,6 +43,7 @@ type Database struct {
|
|||
CurrentRoomState tables.CurrentRoomState
|
||||
BackwardExtremities tables.BackwardsExtremities
|
||||
SendToDevice tables.SendToDevice
|
||||
Filter tables.Filter
|
||||
SendToDeviceWriter *sqlutil.TransactionWriter
|
||||
EDUCache *cache.EDUCache
|
||||
}
|
||||
|
@ -78,7 +79,7 @@ func (d *Database) GetEventsInStreamingRange(
|
|||
}
|
||||
if backwardOrdering {
|
||||
// When using backward ordering, we want the most recent events first.
|
||||
if events, err = d.OutputEvents.SelectRecentEvents(
|
||||
if events, _, err = d.OutputEvents.SelectRecentEvents(
|
||||
ctx, nil, roomID, r, limit, false, false,
|
||||
); err != nil {
|
||||
return
|
||||
|
@ -545,6 +546,18 @@ func (d *Database) addEDUDeltaToResponse(
|
|||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetFilter(
|
||||
ctx context.Context, localpart string, filterID string,
|
||||
) (*gomatrixserverlib.Filter, error) {
|
||||
return d.Filter.SelectFilter(ctx, localpart, filterID)
|
||||
}
|
||||
|
||||
func (d *Database) PutFilter(
|
||||
ctx context.Context, localpart string, filter *gomatrixserverlib.Filter,
|
||||
) (string, error) {
|
||||
return d.Filter.InsertFilter(ctx, filter, localpart)
|
||||
}
|
||||
|
||||
func (d *Database) IncrementalSync(
|
||||
ctx context.Context, res *types.Response,
|
||||
device userapi.Device,
|
||||
|
@ -642,7 +655,8 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
|
|||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||
// See: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L316
|
||||
var recentStreamEvents []types.StreamEvent
|
||||
recentStreamEvents, err = d.OutputEvents.SelectRecentEvents(
|
||||
var limited bool
|
||||
recentStreamEvents, limited, err = d.OutputEvents.SelectRecentEvents(
|
||||
ctx, txn, roomID, r, numRecentEventsPerRoom, true, true,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -670,7 +684,7 @@ func (d *Database) getResponseWithPDUsForCompleteSync(
|
|||
jr := types.NewJoinResponse()
|
||||
jr.Timeline.PrevBatch = prevBatchStr
|
||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
||||
jr.Timeline.Limited = true
|
||||
jr.Timeline.Limited = limited
|
||||
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(stateEvents, gomatrixserverlib.FormatSync)
|
||||
res.Rooms.Join[roomID] = *jr
|
||||
}
|
||||
|
@ -776,7 +790,7 @@ func (d *Database) addRoomDeltaToResponse(
|
|||
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
||||
r.To = delta.membershipPos
|
||||
}
|
||||
recentStreamEvents, err := d.OutputEvents.SelectRecentEvents(
|
||||
recentStreamEvents, limited, err := d.OutputEvents.SelectRecentEvents(
|
||||
ctx, txn, delta.roomID, r,
|
||||
numRecentEventsPerRoom, true, true,
|
||||
)
|
||||
|
@ -796,7 +810,7 @@ func (d *Database) addRoomDeltaToResponse(
|
|||
|
||||
jr.Timeline.PrevBatch = prevBatch.String()
|
||||
jr.Timeline.Events = gomatrixserverlib.HeaderedToClientEvents(recentEvents, gomatrixserverlib.FormatSync)
|
||||
jr.Timeline.Limited = false // TODO: if len(events) >= numRecents + 1 and then set limited:true
|
||||
jr.Timeline.Limited = limited
|
||||
jr.State.Events = gomatrixserverlib.HeaderedToClientEvents(delta.stateEvents, gomatrixserverlib.FormatSync)
|
||||
res.Rooms.Join[delta.roomID] = *jr
|
||||
case gomatrixserverlib.Leave:
|
||||
|
|
137
syncapi/storage/sqlite3/filter_table.go
Normal file
137
syncapi/storage/sqlite3/filter_table.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
// Copyright 2017 Jan Christian Grünhage
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const filterSchema = `
|
||||
-- Stores data about filters
|
||||
CREATE TABLE IF NOT EXISTS syncapi_filter (
|
||||
-- The filter
|
||||
filter TEXT NOT NULL,
|
||||
-- The ID
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The localpart of the Matrix user ID associated to this filter
|
||||
localpart TEXT NOT NULL,
|
||||
|
||||
UNIQUE (id, localpart)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS syncapi_filter_localpart ON syncapi_filter(localpart);
|
||||
`
|
||||
|
||||
const selectFilterSQL = "" +
|
||||
"SELECT filter FROM syncapi_filter WHERE localpart = $1 AND id = $2"
|
||||
|
||||
const selectFilterIDByContentSQL = "" +
|
||||
"SELECT id FROM syncapi_filter WHERE localpart = $1 AND filter = $2"
|
||||
|
||||
const insertFilterSQL = "" +
|
||||
"INSERT INTO syncapi_filter (filter, localpart) VALUES ($1, $2)"
|
||||
|
||||
type filterStatements struct {
|
||||
selectFilterStmt *sql.Stmt
|
||||
selectFilterIDByContentStmt *sql.Stmt
|
||||
insertFilterStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteFilterTable(db *sql.DB) (tables.Filter, error) {
|
||||
_, err := db.Exec(filterSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &filterStatements{}
|
||||
if s.selectFilterStmt, err = db.Prepare(selectFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectFilterIDByContentStmt, err = db.Prepare(selectFilterIDByContentSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertFilterStmt, err = db.Prepare(insertFilterSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) SelectFilter(
|
||||
ctx context.Context, localpart string, filterID string,
|
||||
) (*gomatrixserverlib.Filter, error) {
|
||||
// Retrieve filter from database (stored as canonical JSON)
|
||||
var filterData []byte
|
||||
err := s.selectFilterStmt.QueryRowContext(ctx, localpart, filterID).Scan(&filterData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Unmarshal JSON into Filter struct
|
||||
var filter gomatrixserverlib.Filter
|
||||
if err = json.Unmarshal(filterData, &filter); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &filter, nil
|
||||
}
|
||||
|
||||
func (s *filterStatements) InsertFilter(
|
||||
ctx context.Context, filter *gomatrixserverlib.Filter, localpart string,
|
||||
) (filterID string, err error) {
|
||||
var existingFilterID string
|
||||
|
||||
// Serialise json
|
||||
filterJSON, err := json.Marshal(filter)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
// Remove whitespaces and sort JSON data
|
||||
// needed to prevent from inserting the same filter multiple times
|
||||
filterJSON, err = gomatrixserverlib.CanonicalJSON(filterJSON)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check if filter already exists in the database using its localpart and content
|
||||
//
|
||||
// This can result in a race condition when two clients try to insert the
|
||||
// same filter and localpart at the same time, however this is not a
|
||||
// problem as both calls will result in the same filterID
|
||||
err = s.selectFilterIDByContentStmt.QueryRowContext(ctx,
|
||||
localpart, filterJSON).Scan(&existingFilterID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
// If it does, return the existing ID
|
||||
if existingFilterID != "" {
|
||||
return existingFilterID, err
|
||||
}
|
||||
|
||||
// Otherwise insert the filter and return the new ID
|
||||
res, err := s.insertFilterStmt.ExecContext(ctx, filterJSON, localpart)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rowid, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
filterID = fmt.Sprintf("%d", rowid)
|
||||
return
|
||||
}
|
|
@ -311,7 +311,7 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
ctx context.Context, txn *sql.Tx,
|
||||
roomID string, r types.Range, limit int,
|
||||
chronologicalOrder bool, onlySyncEvents bool,
|
||||
) ([]types.StreamEvent, error) {
|
||||
) ([]types.StreamEvent, bool, error) {
|
||||
var stmt *sql.Stmt
|
||||
if onlySyncEvents {
|
||||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsForSyncStmt)
|
||||
|
@ -319,14 +319,14 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
|
||||
}
|
||||
|
||||
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit)
|
||||
rows, err := stmt.QueryContext(ctx, roomID, r.Low(), r.High(), limit+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectRecentEvents: rows.close() failed")
|
||||
events, err := rowsToStreamEvents(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, false, err
|
||||
}
|
||||
if chronologicalOrder {
|
||||
// The events need to be returned from oldest to latest, which isn't
|
||||
|
@ -336,7 +336,18 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
|
|||
return events[i].StreamPosition < events[j].StreamPosition
|
||||
})
|
||||
}
|
||||
return events, nil
|
||||
// we queried for 1 more than the limit, so if we returned one more mark limited=true
|
||||
limited := false
|
||||
if len(events) > limit {
|
||||
limited = true
|
||||
// re-slice the extra (oldest) event out: in chronological order this is the first entry, else the last.
|
||||
if chronologicalOrder {
|
||||
events = events[1:]
|
||||
} else {
|
||||
events = events[:len(events)-1]
|
||||
}
|
||||
}
|
||||
return events, limited, nil
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsStatements) SelectEarlyEvents(
|
||||
|
|
|
@ -87,6 +87,10 @@ func (d *SyncServerDatasource) prepare() (err error) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filter, err := NewSqliteFilterTable(d.db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Invites: invites,
|
||||
|
@ -95,6 +99,7 @@ func (d *SyncServerDatasource) prepare() (err error) {
|
|||
BackwardExtremities: bwExtrem,
|
||||
CurrentRoomState: roomState,
|
||||
Topology: topology,
|
||||
Filter: filter,
|
||||
SendToDevice: sendToDevice,
|
||||
SendToDeviceWriter: sqlutil.NewTransactionWriter(),
|
||||
EDUCache: cache.New(),
|
||||
|
|
|
@ -44,8 +44,8 @@ type Events interface {
|
|||
InsertEvent(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, addState, removeState []string, transactionID *api.TransactionID, excludeFromSync bool) (streamPos types.StreamPosition, err error)
|
||||
// SelectRecentEvents returns events between the two stream positions: exclusive of low and inclusive of high.
|
||||
// If onlySyncEvents has a value of true, only returns the events that aren't marked as to exclude from sync.
|
||||
// Returns up to `limit` events.
|
||||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, error)
|
||||
// Returns up to `limit` events. Returns `limited=true` if there are more events in this range but we hit the `limit`.
|
||||
SelectRecentEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int, chronologicalOrder bool, onlySyncEvents bool) ([]types.StreamEvent, bool, error)
|
||||
// SelectEarlyEvents returns the earliest events in the given room.
|
||||
SelectEarlyEvents(ctx context.Context, txn *sql.Tx, roomID string, r types.Range, limit int) ([]types.StreamEvent, error)
|
||||
SelectEvents(ctx context.Context, txn *sql.Tx, eventIDs []string) ([]types.StreamEvent, error)
|
||||
|
@ -133,3 +133,8 @@ type SendToDevice interface {
|
|||
DeleteSendToDeviceMessages(ctx context.Context, txn *sql.Tx, nids []types.SendToDeviceNID) (err error)
|
||||
CountSendToDeviceMessages(ctx context.Context, txn *sql.Tx, userID, deviceID string) (count int, err error)
|
||||
}
|
||||
|
||||
type Filter interface {
|
||||
SelectFilter(ctx context.Context, localpart string, filterID string) (*gomatrixserverlib.Filter, error)
|
||||
InsertFilter(ctx context.Context, filter *gomatrixserverlib.Filter, localpart string) (filterID string, err error)
|
||||
}
|
||||
|
|
|
@ -363,7 +363,7 @@ func newTestSyncRequest(userID, deviceID string, since types.StreamingToken) syn
|
|||
timeout: 1 * time.Minute,
|
||||
since: &since,
|
||||
wantFullState: false,
|
||||
limit: defaultTimelineLimit,
|
||||
limit: DefaultTimelineLimit,
|
||||
log: util.GetLogger(context.TODO()),
|
||||
ctx: context.TODO(),
|
||||
}
|
||||
|
|
|
@ -21,14 +21,16 @@ import (
|
|||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const defaultSyncTimeout = time.Duration(0)
|
||||
const defaultTimelineLimit = 20
|
||||
const DefaultTimelineLimit = 20
|
||||
|
||||
type filter struct {
|
||||
Room struct {
|
||||
|
@ -49,7 +51,7 @@ type syncRequest struct {
|
|||
log *log.Entry
|
||||
}
|
||||
|
||||
func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, error) {
|
||||
func newSyncRequest(req *http.Request, device userapi.Device, syncDB storage.Database) (*syncRequest, error) {
|
||||
timeout := getTimeout(req.URL.Query().Get("timeout"))
|
||||
fullState := req.URL.Query().Get("full_state")
|
||||
wantFullState := fullState != "" && fullState != "false"
|
||||
|
@ -66,15 +68,28 @@ func newSyncRequest(req *http.Request, device userapi.Device) (*syncRequest, err
|
|||
tok := types.NewStreamToken(0, 0)
|
||||
since = &tok
|
||||
}
|
||||
timelineLimit := defaultTimelineLimit
|
||||
timelineLimit := DefaultTimelineLimit
|
||||
// TODO: read from stored filters too
|
||||
filterQuery := req.URL.Query().Get("filter")
|
||||
if filterQuery != "" && filterQuery[0] == '{' {
|
||||
// attempt to parse the timeline limit at least
|
||||
var f filter
|
||||
err := json.Unmarshal([]byte(filterQuery), &f)
|
||||
if err == nil && f.Room.Timeline.Limit != nil {
|
||||
timelineLimit = *f.Room.Timeline.Limit
|
||||
if filterQuery != "" {
|
||||
if filterQuery[0] == '{' {
|
||||
// attempt to parse the timeline limit at least
|
||||
var f filter
|
||||
err := json.Unmarshal([]byte(filterQuery), &f)
|
||||
if err == nil && f.Room.Timeline.Limit != nil {
|
||||
timelineLimit = *f.Room.Timeline.Limit
|
||||
}
|
||||
} else {
|
||||
// attempt to load the filter ID
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', device.UserID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return nil, err
|
||||
}
|
||||
f, err := syncDB.GetFilter(req.Context(), localpart, filterQuery)
|
||||
if err == nil {
|
||||
timelineLimit = f.Room.Timeline.Limit
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO: Additional query params: set_presence, filter
|
||||
|
|
|
@ -49,7 +49,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *userapi.
|
|||
var syncData *types.Response
|
||||
|
||||
// Extract values from request
|
||||
syncReq, err := newSyncRequest(req, *device)
|
||||
syncReq, err := newSyncRequest(req, *device, rp.db)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue