mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Some refactoring
This commit is contained in:
parent
9b2b7a6e28
commit
e254594453
7 changed files with 193 additions and 173 deletions
|
@ -190,29 +190,28 @@ func OnIncomingMessagesRequest(
|
||||||
|
|
||||||
switch visibility {
|
switch visibility {
|
||||||
case "joined", "invited": // TODO: treat invites properly
|
case "joined", "invited": // TODO: treat invites properly
|
||||||
membership, _, err := db.MostRecentMembership(req.Context(), roomID, device.UserID) // nolint:govet
|
joinTypes := []string{"join"}
|
||||||
if err != nil {
|
if visibility == "invited" {
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("db.MostRecentMembership for history visibility failed")
|
joinTypes = []string{"invite", "join"}
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
|
||||||
if membership == nil {
|
|
||||||
return util.JSONResponse{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
JSON: jsonerror.Forbidden("History visibility prevents non-members from seeing this room"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pos, err := db.EventPositionInTopology(req.Context(), membership.EventID())
|
|
||||||
if err != nil {
|
|
||||||
util.GetLogger(req.Context()).WithError(err).Error("db.PositionInTopology for history visibility failed")
|
|
||||||
return jsonerror.InternalServerError()
|
|
||||||
}
|
}
|
||||||
|
var joinPDUPos, joinTopoPos types.StreamPosition
|
||||||
|
var leavePDUPos, leaveTopoPos types.StreamPosition
|
||||||
|
if _, joinPDUPos, joinTopoPos, err = db.MostRecentMembership(req.Context(), roomID, device.UserID, joinTypes); err == nil {
|
||||||
if backwardOrdering {
|
if backwardOrdering {
|
||||||
if to.Depth < pos.Depth || to.PDUPosition < pos.PDUPosition {
|
to.PDUPosition = joinPDUPos
|
||||||
to = pos
|
to.Depth = joinTopoPos
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
if from.Depth < pos.Depth || from.PDUPosition < pos.PDUPosition {
|
from.PDUPosition = joinPDUPos
|
||||||
from = pos
|
from.Depth = joinTopoPos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, leavePDUPos, leaveTopoPos, err = db.MostRecentMembership(req.Context(), roomID, device.UserID, []string{"leave", "ban", "kick"}); err == nil {
|
||||||
|
if backwardOrdering {
|
||||||
|
from.PDUPosition = leavePDUPos
|
||||||
|
from.Depth = leaveTopoPos
|
||||||
|
} else {
|
||||||
|
to.PDUPosition = leavePDUPos
|
||||||
|
to.Depth = leaveTopoPos
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -50,7 +50,7 @@ type Database interface {
|
||||||
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
|
RoomReceiptsAfter(ctx context.Context, roomIDs []string, streamPos types.StreamPosition) (types.StreamPosition, []eduAPI.OutputReceiptEvent, error)
|
||||||
|
|
||||||
// MostRecentMembership returns the most recent membership event for the user, along with the global stream position.
|
// MostRecentMembership returns the most recent membership event for the user, along with the global stream position.
|
||||||
MostRecentMembership(ctx context.Context, roomID, userID string) (*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error)
|
MostRecentMembership(ctx context.Context, roomID, userID string, memberships []string) (*gomatrixserverlib.HeaderedEvent, types.StreamPosition, types.StreamPosition, error)
|
||||||
|
|
||||||
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
// AllJoinedUsersInRooms returns a map of room ID to a list of all joined user IDs.
|
||||||
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
|
AllJoinedUsersInRooms(ctx context.Context) (map[string][]string, error)
|
||||||
|
|
|
@ -59,7 +59,7 @@ const upsertMembershipSQL = "" +
|
||||||
const selectMembershipSQL = "" +
|
const selectMembershipSQL = "" +
|
||||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||||
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
|
" WHERE room_id = $1 AND user_id = $2 AND membership = ANY($3)" +
|
||||||
" ORDER BY stream_pos DESC" +
|
" ORDER BY stream_pos ASC" +
|
||||||
" LIMIT 1"
|
" LIMIT 1"
|
||||||
|
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
|
@ -103,7 +103,7 @@ func (s *membershipsStatements) UpsertMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectMembership(
|
func (s *membershipsStatements) SelectMembership(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||||
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
|
stmt := sqlutil.TxStmt(txn, s.selectMembershipStmt)
|
||||||
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
|
err = stmt.QueryRowContext(ctx, roomID, userID, memberships).Scan(&eventID, &streamPos, &topologyPos)
|
||||||
|
|
|
@ -516,20 +516,20 @@ func (d *Database) EventPositionInStream(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) MostRecentMembership(
|
func (d *Database) MostRecentMembership(
|
||||||
ctx context.Context, roomID, userID string,
|
ctx context.Context, roomID, userID string, memberships []string,
|
||||||
) (*gomatrixserverlib.HeaderedEvent, types.StreamPosition, error) {
|
) (*gomatrixserverlib.HeaderedEvent, types.StreamPosition, types.StreamPosition, error) {
|
||||||
event, err := d.CurrentRoomState.SelectStateEvent(ctx, roomID, gomatrixserverlib.MRoomMember, userID)
|
eventID, streamPos, topoPos, err := d.Memberships.SelectMembership(ctx, nil, roomID, userID, memberships)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("d.CurrentRoomState.SelectStateEvent: %w", err)
|
return nil, 0, 0, fmt.Errorf("d.CurrentRoomState.SelectStateEvent: %w", err)
|
||||||
}
|
}
|
||||||
if event == nil {
|
events, err := d.OutputEvents.SelectEvents(ctx, nil, []string{eventID})
|
||||||
return nil, 0, nil
|
|
||||||
}
|
|
||||||
pos, err := d.OutputEvents.SelectPositionInStream(ctx, nil, event.EventID())
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("d.OutputEvents.SelectPositionInStream: %w", err)
|
return nil, 0, 0, fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
|
||||||
}
|
}
|
||||||
return event, pos, nil
|
if len(events) == 0 {
|
||||||
|
return nil, 0, 0, fmt.Errorf("no event returned")
|
||||||
|
}
|
||||||
|
return events[0].HeaderedEvent, streamPos, topoPos, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetFilter(
|
func (d *Database) GetFilter(
|
||||||
|
|
|
@ -60,7 +60,7 @@ const upsertMembershipSQL = "" +
|
||||||
const selectMembershipSQL = "" +
|
const selectMembershipSQL = "" +
|
||||||
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
"SELECT event_id, stream_pos, topological_pos FROM syncapi_memberships" +
|
||||||
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
|
" WHERE room_id = $1 AND user_id = $2 AND membership IN ($3)" +
|
||||||
" ORDER BY stream_pos DESC" +
|
" ORDER BY stream_pos ASC" +
|
||||||
" LIMIT 1"
|
" LIMIT 1"
|
||||||
|
|
||||||
type membershipsStatements struct {
|
type membershipsStatements struct {
|
||||||
|
@ -103,7 +103,7 @@ func (s *membershipsStatements) UpsertMembership(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *membershipsStatements) SelectMembership(
|
func (s *membershipsStatements) SelectMembership(
|
||||||
ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string,
|
ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string,
|
||||||
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
) (eventID string, streamPos, topologyPos types.StreamPosition, err error) {
|
||||||
params := []interface{}{roomID, userID}
|
params := []interface{}{roomID, userID}
|
||||||
for _, membership := range memberships {
|
for _, membership := range memberships {
|
||||||
|
|
|
@ -166,5 +166,5 @@ type Receipts interface {
|
||||||
|
|
||||||
type Memberships interface {
|
type Memberships interface {
|
||||||
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
UpsertMembership(ctx context.Context, txn *sql.Tx, event *gomatrixserverlib.HeaderedEvent, streamPos, topologicalPos types.StreamPosition) error
|
||||||
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
SelectMembership(ctx context.Context, txn *sql.Tx, roomID, userID string, memberships []string) (eventID string, streamPos, topologyPos types.StreamPosition, err error)
|
||||||
}
|
}
|
||||||
|
|
|
@ -151,7 +151,7 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, delta := range stateDeltas {
|
for _, delta := range stateDeltas {
|
||||||
if err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &eventFilter, req.Response); err != nil {
|
if err = p.addRoomDeltaToResponse(ctx, req.Device, r, delta, &stateFilter, &eventFilter, req.Response); err != nil {
|
||||||
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
req.Log.WithError(err).Error("d.addRoomDeltaToResponse failed")
|
||||||
return newPos
|
return newPos
|
||||||
}
|
}
|
||||||
|
@ -160,27 +160,60 @@ func (p *PDUStreamProvider) IncrementalSync(
|
||||||
return r.To
|
return r.To
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PDUStreamProvider) getHistoryVisibility(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID string,
|
||||||
|
) (string, string, error) {
|
||||||
|
historyVisibility := "shared"
|
||||||
|
historyEventID := ""
|
||||||
|
|
||||||
|
historyVisFilter := gomatrixserverlib.DefaultStateFilter()
|
||||||
|
historyVisFilter.Types = []string{"m.room.history_visibility"}
|
||||||
|
|
||||||
|
historyVisEvents, err := p.DB.CurrentState(ctx, roomID, &historyVisFilter)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", fmt.Errorf("p.DB.CurrentState: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, event := range historyVisEvents {
|
||||||
|
if event.Type() != gomatrixserverlib.MRoomHistoryVisibility {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var content struct {
|
||||||
|
HistoryVisibility string `json:"history_visibility"`
|
||||||
|
}
|
||||||
|
if err := json.Unmarshal(event.Content(), &content); err != nil {
|
||||||
|
return historyVisibility, event.EventID(), fmt.Errorf("json.Unmarshal: %w", err)
|
||||||
|
} else {
|
||||||
|
historyVisibility = content.HistoryVisibility
|
||||||
|
historyEventID = event.EventID()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
return historyVisibility, historyEventID, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
device *userapi.Device,
|
device *userapi.Device,
|
||||||
r types.Range,
|
r types.Range,
|
||||||
delta types.StateDelta,
|
delta types.StateDelta,
|
||||||
|
_ *gomatrixserverlib.StateFilter,
|
||||||
eventFilter *gomatrixserverlib.RoomEventFilter,
|
eventFilter *gomatrixserverlib.RoomEventFilter,
|
||||||
res *types.Response,
|
res *types.Response,
|
||||||
) error {
|
) error {
|
||||||
if delta.MembershipPos > 0 && delta.Membership == gomatrixserverlib.Leave {
|
historyVisibility, historyEventID, err := p.getHistoryVisibility(ctx, delta.RoomID)
|
||||||
// make sure we don't leak recent events after the leave event.
|
if err != nil {
|
||||||
// TODO: History visibility makes this somewhat complex to handle correctly. For example:
|
return fmt.Errorf("p.getHistoryVisibility: %w", err)
|
||||||
// TODO: This doesn't work for join -> leave in a single /sync request (see events prior to join).
|
|
||||||
// TODO: This will fail on join -> leave -> sensitive msg -> join -> leave
|
|
||||||
// in a single /sync request
|
|
||||||
// This is all "okay" assuming history_visibility == "shared" which it is by default.
|
|
||||||
if r.Backwards {
|
|
||||||
r.From = delta.MembershipPos
|
|
||||||
} else {
|
|
||||||
r.To = delta.MembershipPos
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r, _, err = p.limitBoundariesUsingHistoryVisibility(
|
||||||
|
ctx, delta.RoomID, device.UserID, historyVisibility, historyEventID, r,
|
||||||
|
); err != nil {
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
recentStreamEvents, limited, err := p.DB.RecentEvents(
|
recentStreamEvents, limited, err := p.DB.RecentEvents(
|
||||||
ctx, delta.RoomID, r,
|
ctx, delta.RoomID, r,
|
||||||
eventFilter, true, true,
|
eventFilter, true, true,
|
||||||
|
@ -235,6 +268,102 @@ func (p *PDUStreamProvider) addRoomDeltaToResponse(
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *PDUStreamProvider) limitBoundariesUsingHistoryVisibility(
|
||||||
|
ctx context.Context,
|
||||||
|
roomID, userID string,
|
||||||
|
historyVisibility, historyEventID string,
|
||||||
|
r types.Range,
|
||||||
|
) (types.Range, string, error) {
|
||||||
|
// Calculate the current history visibility rule.
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var joinPos types.StreamPosition
|
||||||
|
var stateAtEventID string
|
||||||
|
|
||||||
|
// Check and see if the user is in the room.
|
||||||
|
switch historyVisibility {
|
||||||
|
case "invited", "joined", "shared":
|
||||||
|
// Get the most recent membership event of the user and check if
|
||||||
|
// they are still in the room. If not then we will restrict how
|
||||||
|
// much of the room the user can see - they won't see beyond their
|
||||||
|
// leave event.
|
||||||
|
joinTypes := []string{"join"}
|
||||||
|
if historyVisibility == "invited" {
|
||||||
|
joinTypes = []string{"invite", "join"}
|
||||||
|
}
|
||||||
|
if _, joinPos, _, err = p.DB.MostRecentMembership(ctx, roomID, userID, joinTypes); err != nil {
|
||||||
|
// The user isn't a part of the room, or hasn't been invited
|
||||||
|
// to the room.
|
||||||
|
return r, stateAtEventID, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
case "world_readable":
|
||||||
|
// It doesn't matter if the user is joined to the room or not
|
||||||
|
// when the history is world_readable.
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the user is in the room then we next need to work out if we
|
||||||
|
// should bind the beginning of the window based on the join position,
|
||||||
|
// or the position of the history visibility event.
|
||||||
|
switch historyVisibility {
|
||||||
|
case "invited", "joined":
|
||||||
|
if r.Backwards {
|
||||||
|
if r.To > joinPos {
|
||||||
|
r.To = joinPos
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if r.From > joinPos {
|
||||||
|
r.From = joinPos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "shared":
|
||||||
|
// Find the stream position of the history visibility event
|
||||||
|
// and use that as a boundary instead.
|
||||||
|
var historyVisibilityPosition types.StreamPosition
|
||||||
|
historyVisibilityPosition, err = p.DB.EventPositionInStream(ctx, historyEventID)
|
||||||
|
if err != nil {
|
||||||
|
return r, stateAtEventID, fmt.Errorf("p.DB.EventPositionInStream: %w", err)
|
||||||
|
}
|
||||||
|
if r.Backwards {
|
||||||
|
if r.To < historyVisibilityPosition {
|
||||||
|
r.To = historyVisibilityPosition
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if r.From < historyVisibilityPosition {
|
||||||
|
r.From = historyVisibilityPosition
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateAtEventID = historyEventID
|
||||||
|
|
||||||
|
case "world_readable":
|
||||||
|
// Do nothing, as it's OK to reveal the entire timeline in a
|
||||||
|
// world-readable room.
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, work out if the user left the room. If they did then
|
||||||
|
// we will request the state at the leave event from the roomserver.
|
||||||
|
switch historyVisibility {
|
||||||
|
case "invited", "joined", "shared":
|
||||||
|
if leaveEvent, leavePos, _, err := p.DB.MostRecentMembership(ctx, roomID, userID, []string{"leave", "ban", "kick"}); err == nil {
|
||||||
|
if r.Backwards {
|
||||||
|
if r.From > leavePos {
|
||||||
|
r.From = leavePos
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if r.To > leavePos {
|
||||||
|
r.To = leavePos
|
||||||
|
}
|
||||||
|
}
|
||||||
|
stateAtEventID = leaveEvent.EventID()
|
||||||
|
}
|
||||||
|
|
||||||
|
case "world_readable":
|
||||||
|
}
|
||||||
|
|
||||||
|
return r, stateAtEventID, nil
|
||||||
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
// nolint:gocyclo
|
||||||
func (p *PDUStreamProvider) getResponseForCompleteSync(
|
func (p *PDUStreamProvider) getResponseForCompleteSync(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
|
@ -247,80 +376,33 @@ func (p *PDUStreamProvider) getResponseForCompleteSync(
|
||||||
recentEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
recentEvents, stateEvents []*gomatrixserverlib.HeaderedEvent,
|
||||||
prevBatch *types.TopologyToken, limited bool, err error,
|
prevBatch *types.TopologyToken, limited bool, err error,
|
||||||
) {
|
) {
|
||||||
|
historyVisibility, historyEventID, err := p.getHistoryVisibility(ctx, roomID)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateAtEvent string
|
||||||
|
if r, stateAtEvent, err = p.limitBoundariesUsingHistoryVisibility(
|
||||||
|
ctx, roomID, device.UserID, historyVisibility, historyEventID, r,
|
||||||
|
); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if stateAtEvent == "" {
|
||||||
stateEvents, err = p.DB.CurrentState(ctx, roomID, stateFilter)
|
stateEvents, err = p.DB.CurrentState(ctx, roomID, stateFilter)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the current history visibility rule.
|
|
||||||
historyVisibility := "joined"
|
|
||||||
var historyVisibilityEvent *gomatrixserverlib.HeaderedEvent
|
|
||||||
for _, stateEvent := range stateEvents {
|
|
||||||
if stateEvent.Type() == gomatrixserverlib.MRoomHistoryVisibility {
|
|
||||||
var content struct {
|
|
||||||
HistoryVisibility string `json:"history_visibility"`
|
|
||||||
}
|
|
||||||
if err = json.Unmarshal(stateEvent.Content(), &content); err != nil {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
historyVisibility = content.HistoryVisibility
|
|
||||||
historyVisibilityEvent = stateEvent
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch historyVisibility {
|
|
||||||
case "invited", "joined":
|
|
||||||
// Get the most recent membership event of the user and check if
|
|
||||||
// they are still in the room. If not then we will restrict how
|
|
||||||
// much of the room the user can see - they won't see beyond their
|
|
||||||
// leave event.
|
|
||||||
var membershipEvent *gomatrixserverlib.HeaderedEvent
|
|
||||||
var membershipPos types.StreamPosition
|
|
||||||
membershipEvent, membershipPos, err = p.DB.MostRecentMembership(ctx, roomID, device.UserID)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if membershipEvent == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
membership, _ := membershipEvent.Membership()
|
|
||||||
switch membership {
|
|
||||||
case "leave", "ban", "kick":
|
|
||||||
if r.Backwards {
|
|
||||||
r.From = membershipPos
|
|
||||||
} else {
|
} else {
|
||||||
r.To = membershipPos
|
|
||||||
}
|
|
||||||
queryReq := &rsapi.QueryStateAfterEventsRequest{
|
queryReq := &rsapi.QueryStateAfterEventsRequest{
|
||||||
RoomID: roomID,
|
RoomID: roomID,
|
||||||
PrevEventIDs: []string{membershipEvent.EventID()},
|
PrevEventIDs: []string{stateAtEvent},
|
||||||
}
|
}
|
||||||
queryRes := &rsapi.QueryStateAfterEventsResponse{}
|
queryRes := &rsapi.QueryStateAfterEventsResponse{}
|
||||||
if err = p.rsAPI.QueryStateAfterEvents(ctx, queryReq, queryRes); err != nil {
|
if err = p.rsAPI.QueryStateAfterEvents(ctx, queryReq, queryRes); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
stateEvents = p.filterStateEventsAccordingToFilter(queryRes.StateEvents, stateFilter)
|
stateEvents = p.filterStateEventsAccordingToFilter(queryRes.StateEvents, stateFilter)
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
case "shared":
|
|
||||||
// Find the stream position of the history visibility event
|
|
||||||
// and use that as a boundary instead.
|
|
||||||
var historyVisibilityPosition types.StreamPosition
|
|
||||||
historyVisibilityPosition, err = p.DB.EventPositionInStream(ctx, historyVisibilityEvent.EventID())
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if r.Backwards {
|
|
||||||
r.To = historyVisibilityPosition
|
|
||||||
} else {
|
|
||||||
r.From = historyVisibilityPosition
|
|
||||||
}
|
|
||||||
|
|
||||||
case "world_readable":
|
|
||||||
// Do nothing, as it's OK to reveal the entire timeline in a
|
|
||||||
// world-readable room.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
// TODO: When filters are added, we may need to call this multiple times to get enough events.
|
||||||
|
@ -333,10 +415,6 @@ func (p *PDUStreamProvider) getResponseForCompleteSync(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
recentStreamEvents, limited = p.filterStreamEventsAccordingToHistoryVisibility(
|
|
||||||
historyVisibility, recentStreamEvents, device, limited,
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, event := range recentStreamEvents {
|
for _, event := range recentStreamEvents {
|
||||||
if event.HeaderedEvent.Event.StateKey() != nil {
|
if event.HeaderedEvent.Event.StateKey() != nil {
|
||||||
stateEvents = append(stateEvents, event.HeaderedEvent)
|
stateEvents = append(stateEvents, event.HeaderedEvent)
|
||||||
|
@ -460,63 +538,6 @@ func (p *PDUStreamProvider) filterStateEventsAccordingToFilter(
|
||||||
return newState
|
return newState
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint:gocyclo
|
|
||||||
func (p *PDUStreamProvider) filterStreamEventsAccordingToHistoryVisibility(
|
|
||||||
visibility string,
|
|
||||||
recentStreamEvents []types.StreamEvent,
|
|
||||||
device *userapi.Device,
|
|
||||||
limited bool,
|
|
||||||
) ([]types.StreamEvent, bool) {
|
|
||||||
// If the history is world_readable or shared then don't filter.
|
|
||||||
if visibility == "world_readable" || visibility == "shared" {
|
|
||||||
return recentStreamEvents, limited
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO FIXME: We don't fully implement history visibility yet. To avoid leaking events which the
|
|
||||||
// user shouldn't see, we check the recent events and remove any prior to the join event of the user
|
|
||||||
// which is equiv to history_visibility: joined
|
|
||||||
joinEventIndex := -1
|
|
||||||
leaveEventIndex := -1
|
|
||||||
for i := len(recentStreamEvents) - 1; i >= 0; i-- {
|
|
||||||
ev := recentStreamEvents[i]
|
|
||||||
if ev.Type() == gomatrixserverlib.MRoomMember && ev.StateKeyEquals(device.UserID) {
|
|
||||||
membership, _ := ev.Membership()
|
|
||||||
if membership == gomatrixserverlib.Join {
|
|
||||||
joinEventIndex = i
|
|
||||||
if i > 0 {
|
|
||||||
// the create event happens before the first join, so we should cut it at that point instead
|
|
||||||
if recentStreamEvents[i-1].Type() == gomatrixserverlib.MRoomCreate && recentStreamEvents[i-1].StateKeyEquals("") {
|
|
||||||
joinEventIndex = i - 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
break
|
|
||||||
} else if membership == gomatrixserverlib.Leave {
|
|
||||||
leaveEventIndex = i
|
|
||||||
}
|
|
||||||
|
|
||||||
if joinEventIndex != -1 && leaveEventIndex != -1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Default at the start of the array
|
|
||||||
sliceStart := 0
|
|
||||||
// If there is a joinEvent, then cut all events earlier the join
|
|
||||||
if joinEventIndex != -1 {
|
|
||||||
sliceStart = joinEventIndex
|
|
||||||
limited = false // so clients know not to try to backpaginate
|
|
||||||
}
|
|
||||||
// Default to spanning the rest of the array
|
|
||||||
sliceEnd := len(recentStreamEvents)
|
|
||||||
// If there is a leaveEvent, then cut all events after the person left
|
|
||||||
if leaveEventIndex != -1 {
|
|
||||||
sliceEnd = leaveEventIndex + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return recentStreamEvents[sliceStart:sliceEnd], limited
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeDuplicates(stateEvents, recentEvents []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
func removeDuplicates(stateEvents, recentEvents []*gomatrixserverlib.HeaderedEvent) []*gomatrixserverlib.HeaderedEvent {
|
||||||
timeline := map[string]struct{}{}
|
timeline := map[string]struct{}{}
|
||||||
for _, event := range recentEvents {
|
for _, event := range recentEvents {
|
||||||
|
|
Loading…
Reference in a new issue