mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-30 21:12:45 +00:00
Cleanup remaining statekey usage for senderIDs (#3106)
This commit is contained in:
parent
832ccc32f6
commit
77d9e4e93d
62 changed files with 760 additions and 455 deletions
|
@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
|
|||
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
|
||||
if membership == spec.Join {
|
||||
// check it's a local join
|
||||
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
|
||||
if ev.StateKey() == nil {
|
||||
return sp, fmt.Errorf("unexpected nil state_key")
|
||||
}
|
||||
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
||||
if err != nil || userID == nil {
|
||||
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
|
||||
}
|
||||
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
|
@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
|
|||
if msg.Event.StateKey() == nil {
|
||||
return
|
||||
}
|
||||
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
|
||||
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
|
||||
if err != nil || userID == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
|
||||
return
|
||||
}
|
||||
|
||||
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
|
||||
if err != nil {
|
||||
sentry.CaptureException(err)
|
||||
|
@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
|
|||
|
||||
// Notify any active sync requests that the invite has been retired.
|
||||
s.inviteStream.Advance(pduPos)
|
||||
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
|
||||
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
|
||||
if err != nil || userID == nil {
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": msg.EventID,
|
||||
"sender_id": msg.TargetSenderID,
|
||||
log.ErrorKey: err,
|
||||
}).Errorf("failed to find userID for sender")
|
||||
return
|
||||
}
|
||||
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
|
||||
}
|
||||
|
||||
func (s *OutputRoomEventConsumer) onNewPeek(
|
||||
|
|
|
@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter(
|
|||
}
|
||||
}
|
||||
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
|
||||
eventsFiltered = append(eventsFiltered, ev)
|
||||
continue
|
||||
|
||||
user, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
|
||||
if err == nil {
|
||||
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
|
||||
eventsFiltered = append(eventsFiltered, ev)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Always allow history evVis events on boundaries. This is done
|
||||
// by setting the effective evVis to the least restrictive
|
||||
|
|
|
@ -169,12 +169,16 @@ func TrackChangedUsers(
|
|||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for _, state := range stateRes.Rooms {
|
||||
for roomID, state := range stateRes.Rooms {
|
||||
for tuple, membership := range state {
|
||||
if membership != spec.Join {
|
||||
continue
|
||||
}
|
||||
queryRes.UserIDsToCount[tuple.StateKey]--
|
||||
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
||||
if queryErr != nil || user == nil {
|
||||
continue
|
||||
}
|
||||
queryRes.UserIDsToCount[user.String()]--
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,14 +215,18 @@ func TrackChangedUsers(
|
|||
if err != nil {
|
||||
return nil, left, err
|
||||
}
|
||||
for _, state := range stateRes.Rooms {
|
||||
for roomID, state := range stateRes.Rooms {
|
||||
for tuple, membership := range state {
|
||||
if membership != spec.Join {
|
||||
continue
|
||||
}
|
||||
// new user who we weren't previously sharing rooms with
|
||||
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
|
||||
changed = append(changed, tuple.StateKey) // changed is returned
|
||||
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
|
||||
if err != nil || user == nil {
|
||||
continue
|
||||
}
|
||||
changed = append(changed, user.String()) // changed is returned
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
|
|||
roomIDToJoinedMembers map[string][]string
|
||||
}
|
||||
|
||||
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
|
||||
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
|
||||
return nil
|
||||
|
|
|
@ -20,6 +20,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/syncapi/storage"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
|
@ -36,7 +37,8 @@ import (
|
|||
// the event, but the token has already advanced by the time they fetch it, resulting
|
||||
// in missed events.
|
||||
type Notifier struct {
|
||||
lock *sync.RWMutex
|
||||
lock *sync.RWMutex
|
||||
rsAPI api.SyncRoomserverAPI
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
roomIDToJoinedUsers map[string]*userIDSet
|
||||
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
|
||||
|
@ -55,8 +57,9 @@ type Notifier struct {
|
|||
// NewNotifier creates a new notifier set to the given sync position.
|
||||
// In order for this to be of any use, the Notifier needs to be told all rooms and
|
||||
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
|
||||
func NewNotifier() *Notifier {
|
||||
func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
|
||||
return &Notifier{
|
||||
rsAPI: rsAPI,
|
||||
roomIDToJoinedUsers: make(map[string]*userIDSet),
|
||||
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
|
||||
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
|
||||
|
@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent(
|
|||
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
|
||||
// If this is an invite, also add in the invitee to this list.
|
||||
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
|
||||
targetUserID := *ev.StateKey()
|
||||
membership, err := ev.Membership()
|
||||
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||
"Notifier.OnNewEvent: Failed to unmarshal member event",
|
||||
"Notifier.OnNewEvent: Failed to find the userID for this event",
|
||||
)
|
||||
} else {
|
||||
// Keep the joined user map up-to-date
|
||||
switch membership {
|
||||
case spec.Invite:
|
||||
usersToNotify = append(usersToNotify, targetUserID)
|
||||
case spec.Join:
|
||||
// Manually append the new user's ID so they get notified
|
||||
// along all members in the room
|
||||
usersToNotify = append(usersToNotify, targetUserID)
|
||||
n._addJoinedUser(ev.RoomID(), targetUserID)
|
||||
case spec.Leave:
|
||||
fallthrough
|
||||
case spec.Ban:
|
||||
n._removeJoinedUser(ev.RoomID(), targetUserID)
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
|
||||
"Notifier.OnNewEvent: Failed to unmarshal member event",
|
||||
)
|
||||
} else {
|
||||
// Keep the joined user map up-to-date
|
||||
switch membership {
|
||||
case spec.Invite:
|
||||
usersToNotify = append(usersToNotify, targetUserID.String())
|
||||
case spec.Join:
|
||||
// Manually append the new user's ID so they get notified
|
||||
// along all members in the room
|
||||
usersToNotify = append(usersToNotify, targetUserID.String())
|
||||
n._addJoinedUser(ev.RoomID(), targetUserID.String())
|
||||
case spec.Leave:
|
||||
fallthrough
|
||||
case spec.Ban:
|
||||
n._removeJoinedUser(ev.RoomID(), targetUserID.String())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -22,9 +22,11 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
rstypes "github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/dendrite/syncapi/types"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib/spec"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
|
@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
|
|||
}
|
||||
}
|
||||
|
||||
type TestRoomServer struct{ api.SyncRoomserverAPI }
|
||||
|
||||
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return spec.NewUserID(string(senderID), true)
|
||||
}
|
||||
|
||||
// Test that the current position is returned if a request is already behind.
|
||||
func TestImmediateNotification(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
|
||||
if err != nil {
|
||||
|
@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
|
|||
|
||||
// Test that new events to a joined room unblocks the request.
|
||||
func TestNewEventAndJoinedToRoom(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCorrectStream(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
stream := lockedFetchUserStream(n, bob, bobDev)
|
||||
if stream.UserID != bob {
|
||||
|
@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCorrectStreamWakeup(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
awoken := make(chan string)
|
||||
|
||||
|
@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
|
|||
|
||||
// Test that an invite unblocks the request
|
||||
func TestNewInviteEventForUser(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
|
|||
|
||||
// Test that all blocked requests get woken up on a new event.
|
||||
func TestMultipleRequestWakeup(t *testing.T) {
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
|
|||
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
|
||||
// listen as bob. Make bob leave room. Make alice send event to room.
|
||||
// Make sure alice gets woken up only and not bob as well.
|
||||
n := NewNotifier()
|
||||
n := NewNotifier(&TestRoomServer{})
|
||||
n.SetCurrentPosition(syncPositionBefore)
|
||||
n.setUsersJoinedToRooms(map[string][]string{
|
||||
roomID: {alice, bob},
|
||||
|
|
|
@ -85,9 +85,16 @@ func Context(
|
|||
*filter.Rooms = append(*filter.Rooms, roomID)
|
||||
}
|
||||
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
ctx := req.Context()
|
||||
membershipRes := roomserver.QueryMembershipForUserResponse{}
|
||||
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
|
||||
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
|
||||
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
|
||||
logrus.WithError(err).Error("unable to query membership")
|
||||
return util.JSONResponse{
|
||||
|
@ -217,12 +224,9 @@ func Context(
|
|||
}
|
||||
}
|
||||
|
||||
sender := spec.UserID{}
|
||||
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
|
||||
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
|
||||
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
|
||||
}, requestedEvent)
|
||||
response := ContextRespsonse{
|
||||
Event: &ev,
|
||||
EventsAfter: eventsAfterClient,
|
||||
|
|
|
@ -106,8 +106,17 @@ func GetEvent(
|
|||
if err == nil && senderUserID != nil {
|
||||
sender = *senderUserID
|
||||
}
|
||||
|
||||
sk := events[0].StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
|
||||
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -59,14 +59,21 @@ func GetMemberships(
|
|||
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
|
||||
joinedOnly bool, membership, notMembership *string, at string,
|
||||
) util.JSONResponse {
|
||||
userID, err := spec.NewUserID(device.UserID, true)
|
||||
if err != nil {
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusBadRequest,
|
||||
JSON: spec.InvalidParam("Device UserID is invalid"),
|
||||
}
|
||||
}
|
||||
queryReq := api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: device.UserID,
|
||||
UserID: *userID,
|
||||
}
|
||||
|
||||
var queryRes api.QueryMembershipForUserResponse
|
||||
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||
if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
|
||||
util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusInternalServerError,
|
||||
JSON: spec.InternalServerError{},
|
||||
|
|
|
@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
|
|||
}
|
||||
|
||||
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
|
||||
fullUserID, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
req := api.QueryMembershipForUserRequest{
|
||||
RoomID: roomID,
|
||||
UserID: userID,
|
||||
UserID: *fullUserID,
|
||||
}
|
||||
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
|
||||
return api.QueryMembershipForUserResponse{}, err
|
||||
|
|
|
@ -119,9 +119,18 @@ func Relations(
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
res.Chunk = append(
|
||||
res.Chunk,
|
||||
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
|
||||
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
results = append(results, Result{
|
||||
Context: SearchContextResponse{
|
||||
Start: startToken.String(),
|
||||
|
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
|
|||
ProfileInfo: profileInfos,
|
||||
},
|
||||
Rank: eventScore[event.EventID()].Score,
|
||||
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
|
||||
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
|
||||
})
|
||||
roomGroup := groups[event.RoomID()]
|
||||
roomGroup.Results = append(roomGroup.Results, event.EventID())
|
||||
|
|
|
@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
|
|||
|
||||
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
|
||||
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
|
||||
func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) {
|
||||
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
|
||||
func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
|
||||
if ev.StateKey() == nil || *ev.StateKey() == "" {
|
||||
return "", ""
|
||||
}
|
||||
fullUser, err := spec.NewUserID(userID, true)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
|
||||
if err != nil {
|
||||
return "", ""
|
||||
}
|
||||
|
||||
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
|
||||
return "", ""
|
||||
}
|
||||
membership, err := ev.Membership()
|
||||
|
|
|
@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
|
|||
for _, ev := range stateStreamEvents {
|
||||
// Look for our membership in the state events and skip over any
|
||||
// membership events that are not related to us.
|
||||
membership, prevMembership := getMembershipFromEvent(ev.PDU, userID)
|
||||
membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
|
||||
if membership == "" {
|
||||
continue
|
||||
}
|
||||
|
@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
|
|||
|
||||
for roomID, stateStreamEvents := range state {
|
||||
for _, ev := range stateStreamEvents {
|
||||
if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" {
|
||||
if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
|
||||
if membership != spec.Join { // We've already added full state for all joined rooms above.
|
||||
deltas[roomID] = types.StateDelta{
|
||||
Membership: membership,
|
||||
|
|
|
@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
|
|||
user = *sender
|
||||
}
|
||||
|
||||
sk := inviteEvent.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
|
||||
// skip ignored user events
|
||||
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
|
||||
continue
|
||||
}
|
||||
ir := types.NewInviteResponse(inviteEvent, user)
|
||||
ir := types.NewInviteResponse(inviteEvent, user, sk)
|
||||
req.Response.Rooms.Invite[roomID] = ir
|
||||
}
|
||||
|
||||
|
|
|
@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
|
|||
// If this is a gapped incremental sync, we still want this membership
|
||||
isGappedIncremental := limited && incremental
|
||||
// We want this users membership event, keep it in the list
|
||||
stateKey := *event.StateKey()
|
||||
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID {
|
||||
userID := ""
|
||||
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
|
||||
if err == nil && stateKeyUserID != nil {
|
||||
userID = stateKeyUserID.String()
|
||||
}
|
||||
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
if !stateFilter.IncludeRedundantMembers {
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
|
||||
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
|
||||
}
|
||||
delete(timelineUsers, stateKey)
|
||||
delete(timelineUsers, userID)
|
||||
}
|
||||
} else {
|
||||
newStateEvents = append(newStateEvents, event)
|
||||
|
|
|
@ -60,7 +60,7 @@ func AddPublicRoutes(
|
|||
}
|
||||
|
||||
eduCache := caching.NewTypingCache()
|
||||
notifier := notifier.NewNotifier()
|
||||
notifier := notifier.NewNotifier(rsAPI)
|
||||
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
|
||||
notifier.SetCurrentPosition(streams.Latest(context.Background()))
|
||||
if err = notifier.Load(context.Background(), syncDB); err != nil {
|
||||
|
|
|
@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
|
|||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
evs = append(evs, ToClientEvent(se, format, sender))
|
||||
|
||||
sk := se.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
evs = append(evs, ToClientEvent(se, format, sender, sk))
|
||||
}
|
||||
return evs
|
||||
}
|
||||
|
||||
// ToClientEvent converts a single server event to a client event.
|
||||
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
|
||||
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
|
||||
ce := ClientEvent{
|
||||
Content: spec.RawJSON(se.Content()),
|
||||
Sender: sender.String(),
|
||||
Type: se.Type(),
|
||||
StateKey: se.StateKey(),
|
||||
StateKey: stateKey,
|
||||
Unsigned: spec.RawJSON(se.Unsigned()),
|
||||
OriginServerTS: se.OriginServerTS(),
|
||||
EventID: se.EventID(),
|
||||
|
@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
|
|||
}
|
||||
return ce
|
||||
}
|
||||
|
||||
// ToClientEvent converts a single server event to a client event.
|
||||
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
|
||||
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
|
||||
sender := spec.UserID{}
|
||||
userID, err := userIDQuery(event.RoomID(), event.SenderID())
|
||||
if err == nil && userID != nil {
|
||||
sender = *userID
|
||||
}
|
||||
|
||||
sk := event.StateKey()
|
||||
if sk != nil && *sk != "" {
|
||||
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
|
||||
if err == nil && skUserID != nil {
|
||||
skString := skUserID.String()
|
||||
sk = &skString
|
||||
}
|
||||
}
|
||||
return ToClientEvent(event, FormatAll, sender, sk)
|
||||
}
|
||||
|
|
|
@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
|
|||
if err != nil {
|
||||
t.Fatalf("failed to create userID: %s", err)
|
||||
}
|
||||
ce := ToClientEvent(ev, FormatAll, *userID)
|
||||
sk := ""
|
||||
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
|
||||
if ce.EventID != ev.EventID() {
|
||||
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
|
||||
}
|
||||
|
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("failed to create userID: %s", err)
|
||||
}
|
||||
ce := ToClientEvent(ev, FormatSync, *userID)
|
||||
sk := ""
|
||||
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
|
||||
if ce.RoomID != "" {
|
||||
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
|
||||
}
|
||||
|
|
|
@ -539,7 +539,7 @@ type InviteResponse struct {
|
|||
}
|
||||
|
||||
// NewInviteResponse creates an empty response with initialised arrays.
|
||||
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
|
||||
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
|
||||
res := InviteResponse{}
|
||||
res.InviteState.Events = []json.RawMessage{}
|
||||
|
||||
|
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
|
|||
|
||||
// Then we'll see if we can create a partial of the invite event itself.
|
||||
// This is needed for clients to work out *who* sent the invite.
|
||||
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
|
||||
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
|
||||
inviteEvent.Unsigned = nil
|
||||
if ev, err := json.Marshal(inviteEvent); err == nil {
|
||||
res.InviteState.Events = append(res.InviteState.Events, ev)
|
||||
|
|
|
@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
skString := skUserID.String()
|
||||
sk := &skString
|
||||
|
||||
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
|
||||
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
|
||||
j, err := json.Marshal(res)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue