mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
/sync
performance optimizations (#2927)
Since #2849 there is no limit for the current state we fetch to calculate history visibility. In large rooms this can cause us to fetch thousands of membership events we don't really care about. This now only gets the state event types and senders in our timeline, which should significantly reduce the amount of events we fetch from the database. Also removes `MaxTopologicalPosition`, as it is an unnecessary DB call, given we use the result in `topological_position < $1` calls.
This commit is contained in:
parent
8582c7520a
commit
0d0280cf5f
11 changed files with 372 additions and 105 deletions
|
@ -195,7 +195,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent, rew
|
|||
}
|
||||
|
||||
// If we added new hosts, inform them about our known presence events for this room
|
||||
if len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
|
||||
if s.cfg.Matrix.Presence.EnableOutbound && len(addsJoinedHosts) > 0 && ore.Event.Type() == gomatrixserverlib.MRoomMember && ore.Event.StateKey() != nil {
|
||||
membership, _ := ore.Event.Membership()
|
||||
if membership == gomatrixserverlib.Join {
|
||||
s.sendPresence(ore.Event.RoomID(), addsJoinedHosts)
|
||||
|
|
|
@ -16,16 +16,16 @@ package routing
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
type getMembershipResponse struct {
|
||||
|
@ -87,19 +87,18 @@ func GetMemberships(
|
|||
if err != nil {
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
defer db.Rollback() // nolint: errcheck
|
||||
|
||||
atToken, err := types.NewTopologyTokenFromString(at)
|
||||
if err != nil {
|
||||
atToken = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
|
||||
if queryRes.HasBeenInRoom && !queryRes.IsInRoom {
|
||||
// If you have left the room then this will be the members of the room when you left.
|
||||
atToken, err = db.EventPositionInTopology(req.Context(), queryRes.EventID)
|
||||
} else {
|
||||
// If you are joined to the room then this will be the current members of the room.
|
||||
atToken, err = db.MaxTopologicalPosition(req.Context(), roomID)
|
||||
}
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'")
|
||||
return jsonerror.InternalServerError()
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("unable to get 'atToken'")
|
||||
return jsonerror.InternalServerError()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,6 +17,7 @@ package routing
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"sort"
|
||||
"time"
|
||||
|
@ -177,10 +178,11 @@ func OnIncomingMessagesRequest(
|
|||
// If "to" isn't provided, it defaults to either the earliest stream
|
||||
// position (if we're going backward) or to the latest one (if we're
|
||||
// going forward).
|
||||
to, err = setToDefault(req.Context(), snapshot, backwardOrdering, roomID)
|
||||
if err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("setToDefault failed")
|
||||
return jsonerror.InternalServerError()
|
||||
to = types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
|
||||
if backwardOrdering {
|
||||
// go 1 earlier than the first event so we correctly fetch the earliest event
|
||||
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
|
||||
to = types.TopologyToken{}
|
||||
}
|
||||
wasToProvided = false
|
||||
}
|
||||
|
@ -577,24 +579,3 @@ func (r *messagesReq) backfill(roomID string, backwardsExtremities map[string][]
|
|||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
// setToDefault returns the default value for the "to" query parameter of a
|
||||
// request to /messages if not provided. It defaults to either the earliest
|
||||
// topological position (if we're going backward) or to the latest one (if we're
|
||||
// going forward).
|
||||
// Returns an error if there was an issue with retrieving the latest position
|
||||
// from the database
|
||||
func setToDefault(
|
||||
ctx context.Context, snapshot storage.DatabaseTransaction, backwardOrdering bool,
|
||||
roomID string,
|
||||
) (to types.TopologyToken, err error) {
|
||||
if backwardOrdering {
|
||||
// go 1 earlier than the first event so we correctly fetch the earliest event
|
||||
// this is because Database.GetEventsInTopologicalRange is exclusive of the lower-bound.
|
||||
to = types.TopologyToken{}
|
||||
} else {
|
||||
to, err = snapshot.MaxTopologicalPosition(ctx, roomID)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
|
@ -84,8 +84,6 @@ type DatabaseTransaction interface {
|
|||
EventPositionInTopology(ctx context.Context, eventID string) (types.TopologyToken, error)
|
||||
// BackwardExtremitiesForRoom returns a map of backwards extremity event ID to a list of its prev_events.
|
||||
BackwardExtremitiesForRoom(ctx context.Context, roomID string) (backwardExtremities map[string][]string, err error)
|
||||
// MaxTopologicalPosition returns the highest topological position for a given room.
|
||||
MaxTopologicalPosition(ctx context.Context, roomID string) (types.TopologyToken, error)
|
||||
// StreamEventsToEvents converts streamEvent to Event. If device is non-nil and
|
||||
// matches the streamevent.transactionID device then the transaction ID gets
|
||||
// added to the unsigned section of the output event.
|
||||
|
|
|
@ -65,14 +65,6 @@ const selectPositionInTopologySQL = "" +
|
|||
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE event_id = $1"
|
||||
|
||||
// Select the max topological position for the room, then sort by stream position and take the highest,
|
||||
// returning both topological and stream positions.
|
||||
const selectMaxPositionInTopologySQL = "" +
|
||||
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE topological_position=(" +
|
||||
"SELECT MAX(topological_position) FROM syncapi_output_room_events_topology WHERE room_id=$1" +
|
||||
") ORDER BY stream_position DESC LIMIT 1"
|
||||
|
||||
const selectStreamToTopologicalPositionAscSQL = "" +
|
||||
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
|
||||
|
||||
|
@ -84,7 +76,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
@ -107,9 +98,6 @@ func NewPostgresTopologyTable(db *sql.DB) (tables.Topology, error) {
|
|||
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -189,10 +177,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt).QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/tidwall/gjson"
|
||||
|
@ -269,16 +270,6 @@ func (d *DatabaseTransaction) BackwardExtremitiesForRoom(
|
|||
return d.BackwardExtremities.SelectBackwardExtremitiesForRoom(ctx, d.txn, roomID)
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) MaxTopologicalPosition(
|
||||
ctx context.Context, roomID string,
|
||||
) (types.TopologyToken, error) {
|
||||
depth, streamPos, err := d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
||||
if err != nil {
|
||||
return types.TopologyToken{}, err
|
||||
}
|
||||
return types.TopologyToken{Depth: depth, PDUPosition: streamPos}, nil
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) EventPositionInTopology(
|
||||
ctx context.Context, eventID string,
|
||||
) (types.TopologyToken, error) {
|
||||
|
@ -297,11 +288,7 @@ func (d *DatabaseTransaction) StreamToTopologicalPosition(
|
|||
case err == sql.ErrNoRows && backwardOrdering: // no events in range, going backward
|
||||
return types.TopologyToken{PDUPosition: streamPos}, nil
|
||||
case err == sql.ErrNoRows && !backwardOrdering: // no events in range, going forward
|
||||
topoPos, streamPos, err = d.Topology.SelectMaxPositionInTopology(ctx, d.txn, roomID)
|
||||
if err != nil {
|
||||
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectMaxPositionInTopology: %w", err)
|
||||
}
|
||||
return types.TopologyToken{Depth: topoPos, PDUPosition: streamPos}, nil
|
||||
return types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}, nil
|
||||
case err != nil: // some other error happened
|
||||
return types.TopologyToken{}, fmt.Errorf("d.Topology.SelectStreamToTopologicalPosition: %w", err)
|
||||
default:
|
||||
|
|
|
@ -61,10 +61,6 @@ const selectPositionInTopologySQL = "" +
|
|||
"SELECT topological_position, stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE event_id = $1"
|
||||
|
||||
const selectMaxPositionInTopologySQL = "" +
|
||||
"SELECT MAX(topological_position), stream_position FROM syncapi_output_room_events_topology" +
|
||||
" WHERE room_id = $1 ORDER BY stream_position DESC"
|
||||
|
||||
const selectStreamToTopologicalPositionAscSQL = "" +
|
||||
"SELECT topological_position FROM syncapi_output_room_events_topology WHERE room_id = $1 AND stream_position >= $2 ORDER BY topological_position ASC LIMIT 1;"
|
||||
|
||||
|
@ -77,7 +73,6 @@ type outputRoomEventsTopologyStatements struct {
|
|||
selectEventIDsInRangeASCStmt *sql.Stmt
|
||||
selectEventIDsInRangeDESCStmt *sql.Stmt
|
||||
selectPositionInTopologyStmt *sql.Stmt
|
||||
selectMaxPositionInTopologyStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionAscStmt *sql.Stmt
|
||||
selectStreamToTopologicalPositionDescStmt *sql.Stmt
|
||||
}
|
||||
|
@ -102,9 +97,6 @@ func NewSqliteTopologyTable(db *sql.DB) (tables.Topology, error) {
|
|||
if s.selectPositionInTopologyStmt, err = db.Prepare(selectPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectMaxPositionInTopologyStmt, err = db.Prepare(selectMaxPositionInTopologySQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.selectStreamToTopologicalPositionAscStmt, err = db.Prepare(selectStreamToTopologicalPositionAscSQL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -182,11 +174,3 @@ func (s *outputRoomEventsTopologyStatements) SelectStreamToTopologicalPosition(
|
|||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outputRoomEventsTopologyStatements) SelectMaxPositionInTopology(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (pos types.StreamPosition, spos types.StreamPosition, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectMaxPositionInTopologyStmt)
|
||||
err = stmt.QueryRowContext(ctx, roomID).Scan(&pos, &spos)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import (
|
|||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
|
@ -199,10 +200,7 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
|||
_ = MustWriteEvents(t, db, events)
|
||||
|
||||
WithSnapshot(t, db, func(snapshot storage.DatabaseTransaction) {
|
||||
from, err := snapshot.MaxTopologicalPosition(ctx, r.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get MaxTopologicalPosition: %s", err)
|
||||
}
|
||||
from := types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64}
|
||||
t.Logf("max topo pos = %+v", from)
|
||||
// head towards the beginning of time
|
||||
to := types.TopologyToken{}
|
||||
|
@ -219,6 +217,88 @@ func TestGetEventsInRangeWithTopologyToken(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestStreamToTopologicalPosition(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
r := test.NewRoom(t, alice)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
roomID string
|
||||
streamPos types.StreamPosition
|
||||
backwardOrdering bool
|
||||
wantToken types.TopologyToken
|
||||
}{
|
||||
{
|
||||
name: "forward ordering found streamPos returns found position",
|
||||
roomID: r.ID,
|
||||
streamPos: 1,
|
||||
backwardOrdering: false,
|
||||
wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1},
|
||||
},
|
||||
{
|
||||
name: "forward ordering not found streamPos returns max position",
|
||||
roomID: r.ID,
|
||||
streamPos: 100,
|
||||
backwardOrdering: false,
|
||||
wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64},
|
||||
},
|
||||
{
|
||||
name: "backward ordering found streamPos returns found position",
|
||||
roomID: r.ID,
|
||||
streamPos: 1,
|
||||
backwardOrdering: true,
|
||||
wantToken: types.TopologyToken{Depth: 1, PDUPosition: 1},
|
||||
},
|
||||
{
|
||||
name: "backward ordering not found streamPos returns maxDepth with param pduPosition",
|
||||
roomID: r.ID,
|
||||
streamPos: 100,
|
||||
backwardOrdering: true,
|
||||
wantToken: types.TopologyToken{Depth: 5, PDUPosition: 100},
|
||||
},
|
||||
{
|
||||
name: "backward non-existent room returns zero token",
|
||||
roomID: "!doesnotexist:localhost",
|
||||
streamPos: 1,
|
||||
backwardOrdering: true,
|
||||
wantToken: types.TopologyToken{Depth: 0, PDUPosition: 1},
|
||||
},
|
||||
{
|
||||
name: "forward non-existent room returns max token",
|
||||
roomID: "!doesnotexist:localhost",
|
||||
streamPos: 1,
|
||||
backwardOrdering: false,
|
||||
wantToken: types.TopologyToken{Depth: math.MaxInt64, PDUPosition: math.MaxInt64},
|
||||
},
|
||||
}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close, closeBase := MustCreateDatabase(t, dbType)
|
||||
defer close()
|
||||
defer closeBase()
|
||||
|
||||
txn, err := db.NewDatabaseTransaction(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer txn.Rollback()
|
||||
MustWriteEvents(t, db, r.Events())
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
token, err := txn.StreamToTopologicalPosition(ctx, tc.roomID, tc.streamPos, tc.backwardOrdering)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if tc.wantToken != token {
|
||||
t.Fatalf("expected token %q, got %q", tc.wantToken, token)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
/*
|
||||
// The purpose of this test is to make sure that backpagination returns all events, even if some events have the same depth.
|
||||
// For cases where events have the same depth, the streaming token should be used to tie break so events written via WriteEvent
|
||||
|
|
|
@ -91,8 +91,6 @@ type Topology interface {
|
|||
SelectEventIDsInRange(ctx context.Context, txn *sql.Tx, roomID string, minDepth, maxDepth, maxStreamPos types.StreamPosition, limit int, chronologicalOrder bool) (eventIDs []string, err error)
|
||||
// SelectPositionInTopology returns the depth and stream position of a given event in the topology of the room it belongs to.
|
||||
SelectPositionInTopology(ctx context.Context, txn *sql.Tx, eventID string) (depth, spos types.StreamPosition, err error)
|
||||
// SelectMaxPositionInTopology returns the event which has the highest depth, and if there are multiple, the event with the highest stream position.
|
||||
SelectMaxPositionInTopology(ctx context.Context, txn *sql.Tx, roomID string) (depth types.StreamPosition, spos types.StreamPosition, err error)
|
||||
// SelectStreamToTopologicalPosition converts a stream position to a topological position by finding the nearest topological position in the room.
|
||||
SelectStreamToTopologicalPosition(ctx context.Context, txn *sql.Tx, roomID string, streamPos types.StreamPosition, forward bool) (topoPos types.StreamPosition, err error)
|
||||
}
|
||||
|
|
|
@ -384,19 +384,32 @@ func applyHistoryVisibilityFilter(
|
|||
roomID, userID string,
|
||||
recentEvents []*gomatrixserverlib.HeaderedEvent,
|
||||
) ([]*gomatrixserverlib.HeaderedEvent, error) {
|
||||
// We need to make sure we always include the latest states events, if they are in the timeline.
|
||||
// We grep at least limit * 2 events, to ensure we really get the needed events.
|
||||
filter := gomatrixserverlib.DefaultStateFilter()
|
||||
stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil)
|
||||
if err != nil {
|
||||
// Not a fatal error, we can continue without the stateEvents,
|
||||
// they are only needed if there are state events in the timeline.
|
||||
logrus.WithError(err).Warnf("Failed to get current room state for history visibility")
|
||||
// We need to make sure we always include the latest state events, if they are in the timeline.
|
||||
alwaysIncludeIDs := make(map[string]struct{})
|
||||
var stateTypes []string
|
||||
var senders []string
|
||||
for _, ev := range recentEvents {
|
||||
if ev.StateKey() != nil {
|
||||
stateTypes = append(stateTypes, ev.Type())
|
||||
senders = append(senders, ev.Sender())
|
||||
}
|
||||
}
|
||||
alwaysIncludeIDs := make(map[string]struct{}, len(stateEvents))
|
||||
for _, ev := range stateEvents {
|
||||
alwaysIncludeIDs[ev.EventID()] = struct{}{}
|
||||
|
||||
// Only get the state again if there are state events in the timeline
|
||||
if len(stateTypes) > 0 {
|
||||
filter := gomatrixserverlib.DefaultStateFilter()
|
||||
filter.Types = &stateTypes
|
||||
filter.Senders = &senders
|
||||
stateEvents, err := snapshot.CurrentState(ctx, roomID, &filter, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get current room state for history visibility calculation: %w", err)
|
||||
}
|
||||
|
||||
for _, ev := range stateEvents {
|
||||
alwaysIncludeIDs[ev.EventID()] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
startTime := time.Now()
|
||||
events, err := internal.ApplyHistoryVisibilityFilter(ctx, snapshot, rsAPI, recentEvents, alwaysIncludeIDs, userID, "sync")
|
||||
if err != nil {
|
||||
|
|
|
@ -521,6 +521,252 @@ func verifyEventVisible(t *testing.T, wantVisible bool, wantVisibleEvent *gomatr
|
|||
}
|
||||
}
|
||||
|
||||
func TestGetMembership(t *testing.T) {
|
||||
alice := test.NewUser(t)
|
||||
|
||||
aliceDev := userapi.Device{
|
||||
ID: "ALICEID",
|
||||
UserID: alice.ID,
|
||||
AccessToken: "ALICE_BEARER_TOKEN",
|
||||
DisplayName: "Alice",
|
||||
AccountType: userapi.AccountTypeUser,
|
||||
}
|
||||
|
||||
bob := test.NewUser(t)
|
||||
bobDev := userapi.Device{
|
||||
ID: "BOBID",
|
||||
UserID: bob.ID,
|
||||
AccessToken: "notjoinedtoanyrooms",
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
roomID string
|
||||
additionalEvents func(t *testing.T, room *test.Room)
|
||||
request func(t *testing.T, room *test.Room) *http.Request
|
||||
wantOK bool
|
||||
wantMemberCount int
|
||||
useSleep bool // :/
|
||||
}{
|
||||
{
|
||||
name: "/members - Alice joined",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
wantOK: true,
|
||||
wantMemberCount: 1,
|
||||
},
|
||||
{
|
||||
name: "/members - Bob never joined",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": bobDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "/joined_members - Bob never joined",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": bobDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "/joined_members - Alice joined",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "Alice leaves before Bob joins, should not be able to see Bob",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(alice.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
useSleep: true,
|
||||
wantOK: true,
|
||||
wantMemberCount: 1,
|
||||
},
|
||||
{
|
||||
name: "Alice leaves after Bob joins, should be able to see Bob",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(alice.ID))
|
||||
},
|
||||
useSleep: true,
|
||||
wantOK: true,
|
||||
wantMemberCount: 2,
|
||||
},
|
||||
{
|
||||
name: "/joined_members - Alice leaves, shouldn't be able to see members ",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/joined_members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(alice.ID))
|
||||
},
|
||||
useSleep: true,
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "'at' specified, returns memberships before Bob joins",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
"at": "t2_5",
|
||||
}))
|
||||
},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
useSleep: true,
|
||||
wantOK: true,
|
||||
wantMemberCount: 1,
|
||||
},
|
||||
{
|
||||
name: "'membership=leave' specified, returns no memberships",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
"membership": "leave",
|
||||
}))
|
||||
},
|
||||
wantOK: true,
|
||||
wantMemberCount: 0,
|
||||
},
|
||||
{
|
||||
name: "'not_membership=join' specified, returns no memberships",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
"not_membership": "join",
|
||||
}))
|
||||
},
|
||||
wantOK: true,
|
||||
wantMemberCount: 0,
|
||||
},
|
||||
{
|
||||
name: "'not_membership=leave' & 'membership=join' specified, returns correct memberships",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", room.ID), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
"not_membership": "leave",
|
||||
"membership": "join",
|
||||
}))
|
||||
},
|
||||
additionalEvents: func(t *testing.T, room *test.Room) {
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "join",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}, test.WithStateKey(bob.ID))
|
||||
},
|
||||
wantOK: true,
|
||||
wantMemberCount: 1,
|
||||
},
|
||||
{
|
||||
name: "non-existent room ID",
|
||||
request: func(t *testing.T, room *test.Room) *http.Request {
|
||||
return test.NewRequest(t, "GET", fmt.Sprintf("/_matrix/client/v3/rooms/%s/members", "!notavalidroom:test"), test.WithQueryParams(map[string]string{
|
||||
"access_token": aliceDev.AccessToken,
|
||||
}))
|
||||
},
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
|
||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||
defer close()
|
||||
|
||||
jsctx, _ := base.NATS.Prepare(base.ProcessContext, &base.Cfg.Global.JetStream)
|
||||
defer jetstream.DeleteAllStreams(jsctx, &base.Cfg.Global.JetStream)
|
||||
|
||||
// Use an actual roomserver for this
|
||||
rsAPI := roomserver.NewInternalAPI(base)
|
||||
rsAPI.SetFederationAPI(nil, nil)
|
||||
|
||||
AddPublicRoutes(base, &syncUserAPI{accounts: []userapi.Device{aliceDev, bobDev}}, rsAPI, &syncKeyAPI{})
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
room := test.NewRoom(t, alice)
|
||||
t.Cleanup(func() {
|
||||
t.Logf("running cleanup for %s", tc.name)
|
||||
})
|
||||
// inject additional events
|
||||
if tc.additionalEvents != nil {
|
||||
tc.additionalEvents(t, room)
|
||||
}
|
||||
if err := api.SendEvents(context.Background(), rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||
t.Fatalf("failed to send events: %v", err)
|
||||
}
|
||||
|
||||
// wait for the events to come down sync
|
||||
if tc.useSleep {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
} else {
|
||||
syncUntil(t, base, aliceDev.AccessToken, false, func(syncBody string) bool {
|
||||
// wait for the last sent eventID to come down sync
|
||||
path := fmt.Sprintf(`rooms.join.%s.timeline.events.#(event_id=="%s")`, room.ID, room.Events()[len(room.Events())-1].EventID())
|
||||
return gjson.Get(syncBody, path).Exists()
|
||||
})
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
base.PublicClientAPIMux.ServeHTTP(w, tc.request(t, room))
|
||||
if w.Code != 200 && tc.wantOK {
|
||||
t.Logf("%s", w.Body.String())
|
||||
t.Fatalf("got HTTP %d want %d", w.Code, 200)
|
||||
}
|
||||
t.Logf("[%s] Resp: %s", tc.name, w.Body.String())
|
||||
|
||||
// check we got the expected events
|
||||
if tc.wantOK {
|
||||
memberCount := len(gjson.GetBytes(w.Body.Bytes(), "chunk").Array())
|
||||
if memberCount != tc.wantMemberCount {
|
||||
t.Fatalf("expected %d members, got %d", tc.wantMemberCount, memberCount)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSendToDevice(t *testing.T) {
|
||||
test.WithAllDatabases(t, testSendToDevice)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue