Cleanup remaining statekey usage for senderIDs (#3106)

This commit is contained in:
devonh 2023-06-12 11:19:25 +00:00 committed by GitHub
parent 832ccc32f6
commit 77d9e4e93d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
62 changed files with 760 additions and 455 deletions

View file

@ -373,7 +373,15 @@ func (s *OutputRoomEventConsumer) notifyJoinedPeeks(ctx context.Context, ev *rst
// TODO: check that it's a join and not a profile change (means unmarshalling prev_content)
if membership == spec.Join {
// check it's a local join
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *ev.StateKey()); err != nil {
if ev.StateKey() == nil {
return sp, fmt.Errorf("unexpected nil state_key")
}
userID, err := s.rsAPI.QueryUserIDForSender(ctx, ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil || userID == nil {
return sp, fmt.Errorf("failed getting userID for sender: %w", err)
}
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return sp, nil
}
@ -395,9 +403,15 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
if msg.Event.StateKey() == nil {
return
}
if _, _, err := s.cfg.Matrix.SplitLocalID('@', *msg.Event.StateKey()); err != nil {
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.Event.RoomID(), spec.SenderID(*msg.Event.StateKey()))
if err != nil || userID == nil {
return
}
if !s.cfg.Matrix.IsLocalServerName(userID.Domain()) {
return
}
pduPos, err := s.db.AddInviteEvent(ctx, msg.Event)
if err != nil {
sentry.CaptureException(err)
@ -440,7 +454,16 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
// Notify any active sync requests that the invite has been retired.
s.inviteStream.Advance(pduPos)
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, msg.TargetUserID)
userID, err := s.rsAPI.QueryUserIDForSender(ctx, msg.RoomID, msg.TargetSenderID)
if err != nil || userID == nil {
log.WithFields(log.Fields{
"event_id": msg.EventID,
"sender_id": msg.TargetSenderID,
log.ErrorKey: err,
}).Errorf("failed to find userID for sender")
return
}
s.notifier.OnNewInvite(types.StreamingToken{InvitePosition: pduPos}, userID.String())
}
func (s *OutputRoomEventConsumer) onNewPeek(

View file

@ -134,9 +134,17 @@ func ApplyHistoryVisibilityFilter(
}
}
// NOTSPEC: Always allow user to see their own membership events (spec contains more "rules")
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(userID) {
eventsFiltered = append(eventsFiltered, ev)
continue
user, err := spec.NewUserID(userID, true)
if err != nil {
return nil, err
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *user)
if err == nil {
if ev.Type() == spec.MRoomMember && ev.StateKeyEquals(string(senderID)) {
eventsFiltered = append(eventsFiltered, ev)
continue
}
}
// Always allow history evVis events on boundaries. This is done
// by setting the effective evVis to the least restrictive

View file

@ -169,12 +169,16 @@ func TrackChangedUsers(
if err != nil {
return nil, nil, err
}
for _, state := range stateRes.Rooms {
for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
queryRes.UserIDsToCount[tuple.StateKey]--
user, queryErr := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
if queryErr != nil || user == nil {
continue
}
queryRes.UserIDsToCount[user.String()]--
}
}
@ -211,14 +215,18 @@ func TrackChangedUsers(
if err != nil {
return nil, left, err
}
for _, state := range stateRes.Rooms {
for roomID, state := range stateRes.Rooms {
for tuple, membership := range state {
if membership != spec.Join {
continue
}
// new user who we weren't previously sharing rooms with
if _, ok := queryRes.UserIDsToCount[tuple.StateKey]; !ok {
changed = append(changed, tuple.StateKey) // changed is returned
user, err := rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(tuple.StateKey))
if err != nil || user == nil {
continue
}
changed = append(changed, user.String()) // changed is returned
}
}
}

View file

@ -64,6 +64,10 @@ type mockRoomserverAPI struct {
roomIDToJoinedMembers map[string][]string
}
func (s *mockRoomserverAPI) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
// QueryRoomsForUser retrieves a list of room IDs matching the given query.
func (s *mockRoomserverAPI) QueryRoomsForUser(ctx context.Context, req *api.QueryRoomsForUserRequest, res *api.QueryRoomsForUserResponse) error {
return nil

View file

@ -20,6 +20,7 @@ import (
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/types"
@ -36,7 +37,8 @@ import (
// the event, but the token has already advanced by the time they fetch it, resulting
// in missed events.
type Notifier struct {
lock *sync.RWMutex
lock *sync.RWMutex
rsAPI api.SyncRoomserverAPI
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
roomIDToJoinedUsers map[string]*userIDSet
// A map of RoomID => Set<UserID> : Must only be accessed by the OnNewEvent goroutine
@ -55,8 +57,9 @@ type Notifier struct {
// NewNotifier creates a new notifier set to the given sync position.
// In order for this to be of any use, the Notifier needs to be told all rooms and
// the joined users within each of them by calling Notifier.Load(*storage.SyncServerDatabase).
func NewNotifier() *Notifier {
func NewNotifier(rsAPI api.SyncRoomserverAPI) *Notifier {
return &Notifier{
rsAPI: rsAPI,
roomIDToJoinedUsers: make(map[string]*userIDSet),
roomIDToPeekingDevices: make(map[string]peekingDeviceSet),
userDeviceStreams: make(map[string]map[string]*UserDeviceStream),
@ -104,26 +107,32 @@ func (n *Notifier) OnNewEvent(
peekingDevicesToNotify := n._peekingDevices(ev.RoomID())
// If this is an invite, also add in the invitee to this list.
if ev.Type() == "m.room.member" && ev.StateKey() != nil {
targetUserID := *ev.StateKey()
membership, err := ev.Membership()
targetUserID, err := n.rsAPI.QueryUserIDForSender(context.Background(), ev.RoomID(), spec.SenderID(*ev.StateKey()))
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to unmarshal member event",
"Notifier.OnNewEvent: Failed to find the userID for this event",
)
} else {
// Keep the joined user map up-to-date
switch membership {
case spec.Invite:
usersToNotify = append(usersToNotify, targetUserID)
case spec.Join:
// Manually append the new user's ID so they get notified
// along all members in the room
usersToNotify = append(usersToNotify, targetUserID)
n._addJoinedUser(ev.RoomID(), targetUserID)
case spec.Leave:
fallthrough
case spec.Ban:
n._removeJoinedUser(ev.RoomID(), targetUserID)
membership, err := ev.Membership()
if err != nil {
log.WithError(err).WithField("event_id", ev.EventID()).Errorf(
"Notifier.OnNewEvent: Failed to unmarshal member event",
)
} else {
// Keep the joined user map up-to-date
switch membership {
case spec.Invite:
usersToNotify = append(usersToNotify, targetUserID.String())
case spec.Join:
// Manually append the new user's ID so they get notified
// along all members in the room
usersToNotify = append(usersToNotify, targetUserID.String())
n._addJoinedUser(ev.RoomID(), targetUserID.String())
case spec.Leave:
fallthrough
case spec.Ban:
n._removeJoinedUser(ev.RoomID(), targetUserID.String())
}
}
}
}

View file

@ -22,9 +22,11 @@ import (
"testing"
"time"
"github.com/matrix-org/dendrite/roomserver/api"
rstypes "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/syncapi/types"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib/spec"
"github.com/matrix-org/util"
)
@ -105,9 +107,15 @@ func mustEqualPositions(t *testing.T, got, want types.StreamingToken) {
}
}
type TestRoomServer struct{ api.SyncRoomserverAPI }
func (t *TestRoomServer) QueryUserIDForSender(ctx context.Context, roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return spec.NewUserID(string(senderID), true)
}
// Test that the current position is returned if a request is already behind.
func TestImmediateNotification(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
pos, err := waitForEvents(n, newTestSyncRequest(alice, aliceDev, syncPositionVeryOld))
if err != nil {
@ -118,7 +126,7 @@ func TestImmediateNotification(t *testing.T) {
// Test that new events to a joined room unblocks the request.
func TestNewEventAndJoinedToRoom(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@ -144,7 +152,7 @@ func TestNewEventAndJoinedToRoom(t *testing.T) {
}
func TestCorrectStream(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
stream := lockedFetchUserStream(n, bob, bobDev)
if stream.UserID != bob {
@ -156,7 +164,7 @@ func TestCorrectStream(t *testing.T) {
}
func TestCorrectStreamWakeup(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
awoken := make(chan string)
@ -184,7 +192,7 @@ func TestCorrectStreamWakeup(t *testing.T) {
// Test that an invite unblocks the request
func TestNewInviteEventForUser(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@ -241,7 +249,7 @@ func TestEDUWakeup(t *testing.T) {
// Test that all blocked requests get woken up on a new event.
func TestMultipleRequestWakeup(t *testing.T) {
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},
@ -278,7 +286,7 @@ func TestMultipleRequestWakeup(t *testing.T) {
func TestNewEventAndWasPreviouslyJoinedToRoom(t *testing.T) {
// listen as bob. Make bob leave room. Make alice send event to room.
// Make sure alice gets woken up only and not bob as well.
n := NewNotifier()
n := NewNotifier(&TestRoomServer{})
n.SetCurrentPosition(syncPositionBefore)
n.setUsersJoinedToRooms(map[string][]string{
roomID: {alice, bob},

View file

@ -85,9 +85,16 @@ func Context(
*filter.Rooms = append(*filter.Rooms, roomID)
}
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
ctx := req.Context()
membershipRes := roomserver.QueryMembershipForUserResponse{}
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: device.UserID, RoomID: roomID}
membershipReq := roomserver.QueryMembershipForUserRequest{UserID: *userID, RoomID: roomID}
if err = rsAPI.QueryMembershipForUser(ctx, &membershipReq, &membershipRes); err != nil {
logrus.WithError(err).Error("unable to query membership")
return util.JSONResponse{
@ -217,12 +224,9 @@ func Context(
}
}
sender := spec.UserID{}
userID, err := rsAPI.QueryUserIDForSender(ctx, requestedEvent.RoomID(), requestedEvent.SenderID())
if err == nil && userID != nil {
sender = *userID
}
ev := synctypes.ToClientEvent(&requestedEvent, synctypes.FormatAll, sender)
ev := synctypes.ToClientEventDefault(func(roomID string, senderID spec.SenderID) (*spec.UserID, error) {
return rsAPI.QueryUserIDForSender(ctx, roomID, senderID)
}, requestedEvent)
response := ContextRespsonse{
Event: &ev,
EventsAfter: eventsAfterClient,

View file

@ -106,8 +106,17 @@ func GetEvent(
if err == nil && senderUserID != nil {
sender = *senderUserID
}
sk := events[0].StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(ctx, events[0].RoomID(), spec.SenderID(*events[0].StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return util.JSONResponse{
Code: http.StatusOK,
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender),
JSON: synctypes.ToClientEvent(events[0], synctypes.FormatAll, sender, sk),
}
}

View file

@ -59,14 +59,21 @@ func GetMemberships(
syncDB storage.Database, rsAPI api.SyncRoomserverAPI,
joinedOnly bool, membership, notMembership *string, at string,
) util.JSONResponse {
userID, err := spec.NewUserID(device.UserID, true)
if err != nil {
return util.JSONResponse{
Code: http.StatusBadRequest,
JSON: spec.InvalidParam("Device UserID is invalid"),
}
}
queryReq := api.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: device.UserID,
UserID: *userID,
}
var queryRes api.QueryMembershipForUserResponse
if err := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); err != nil {
util.GetLogger(req.Context()).WithError(err).Error("rsAPI.QueryMembershipsForRoom failed")
if queryErr := rsAPI.QueryMembershipForUser(req.Context(), &queryReq, &queryRes); queryErr != nil {
util.GetLogger(req.Context()).WithError(queryErr).Error("rsAPI.QueryMembershipsForRoom failed")
return util.JSONResponse{
Code: http.StatusInternalServerError,
JSON: spec.InternalServerError{},

View file

@ -296,9 +296,13 @@ func OnIncomingMessagesRequest(
}
func getMembershipForUser(ctx context.Context, roomID, userID string, rsAPI api.SyncRoomserverAPI) (resp api.QueryMembershipForUserResponse, err error) {
fullUserID, err := spec.NewUserID(userID, true)
if err != nil {
return resp, err
}
req := api.QueryMembershipForUserRequest{
RoomID: roomID,
UserID: userID,
UserID: *fullUserID,
}
if err := rsAPI.QueryMembershipForUser(ctx, &req, &resp); err != nil {
return api.QueryMembershipForUserResponse{}, err

View file

@ -119,9 +119,18 @@ func Relations(
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
res.Chunk = append(
res.Chunk,
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender),
synctypes.ToClientEvent(event.PDU, synctypes.FormatAll, sender, sk),
)
}

View file

@ -235,6 +235,15 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := rsAPI.QueryUserIDForSender(req.Context(), event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
results = append(results, Result{
Context: SearchContextResponse{
Start: startToken.String(),
@ -248,7 +257,7 @@ func Search(req *http.Request, device *api.Device, syncDB storage.Database, fts
ProfileInfo: profileInfos,
},
Rank: eventScore[event.EventID()].Score,
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender),
Result: synctypes.ToClientEvent(event, synctypes.FormatAll, sender, sk),
})
roomGroup := groups[event.RoomID()]
roomGroup.Results = append(roomGroup.Results, event.EventID())

View file

@ -507,8 +507,20 @@ func (d *Database) CleanSendToDeviceUpdates(
// getMembershipFromEvent returns the value of content.membership iff the event is a state event
// with type 'm.room.member' and state_key of userID. Otherwise, an empty string is returned.
func getMembershipFromEvent(ev gomatrixserverlib.PDU, userID string) (string, string) {
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(userID) {
func getMembershipFromEvent(ctx context.Context, ev gomatrixserverlib.PDU, userID string, rsAPI api.SyncRoomserverAPI) (string, string) {
if ev.StateKey() == nil || *ev.StateKey() == "" {
return "", ""
}
fullUser, err := spec.NewUserID(userID, true)
if err != nil {
return "", ""
}
senderID, err := rsAPI.QuerySenderIDForUser(ctx, ev.RoomID(), *fullUser)
if err != nil {
return "", ""
}
if ev.Type() != "m.room.member" || !ev.StateKeyEquals(string(senderID)) {
return "", ""
}
membership, err := ev.Membership()

View file

@ -430,7 +430,7 @@ func (d *DatabaseTransaction) GetStateDeltas(
for _, ev := range stateStreamEvents {
// Look for our membership in the state events and skip over any
// membership events that are not related to us.
membership, prevMembership := getMembershipFromEvent(ev.PDU, userID)
membership, prevMembership := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI)
if membership == "" {
continue
}
@ -556,7 +556,7 @@ func (d *DatabaseTransaction) GetStateDeltasForFullStateSync(
for roomID, stateStreamEvents := range state {
for _, ev := range stateStreamEvents {
if membership, _ := getMembershipFromEvent(ev.PDU, userID); membership != "" {
if membership, _ := getMembershipFromEvent(ctx, ev.PDU, userID, rsAPI); membership != "" {
if membership != spec.Join { // We've already added full state for all joined rooms above.
deltas[roomID] = types.StateDelta{
Membership: membership,

View file

@ -70,11 +70,20 @@ func (p *InviteStreamProvider) IncrementalSync(
user = *sender
}
sk := inviteEvent.StateKey()
if sk != nil && *sk != "" {
skUserID, err := p.rsAPI.QueryUserIDForSender(ctx, inviteEvent.RoomID(), spec.SenderID(*inviteEvent.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
// skip ignored user events
if _, ok := req.IgnoredUsers.List[user.String()]; ok {
continue
}
ir := types.NewInviteResponse(inviteEvent, user)
ir := types.NewInviteResponse(inviteEvent, user, sk)
req.Response.Rooms.Invite[roomID] = ir
}

View file

@ -605,13 +605,17 @@ func (p *PDUStreamProvider) lazyLoadMembers(
// If this is a gapped incremental sync, we still want this membership
isGappedIncremental := limited && incremental
// We want this users membership event, keep it in the list
stateKey := *event.StateKey()
if _, ok := timelineUsers[stateKey]; ok || isGappedIncremental || stateKey == device.UserID {
userID := ""
stateKeyUserID, err := p.rsAPI.QueryUserIDForSender(ctx, roomID, spec.SenderID(*event.StateKey()))
if err == nil && stateKeyUserID != nil {
userID = stateKeyUserID.String()
}
if _, ok := timelineUsers[userID]; ok || isGappedIncremental || userID == device.UserID {
newStateEvents = append(newStateEvents, event)
if !stateFilter.IncludeRedundantMembers {
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, stateKey, event.EventID())
p.lazyLoadCache.StoreLazyLoadedUser(device, roomID, userID, event.EventID())
}
delete(timelineUsers, stateKey)
delete(timelineUsers, userID)
}
} else {
newStateEvents = append(newStateEvents, event)

View file

@ -60,7 +60,7 @@ func AddPublicRoutes(
}
eduCache := caching.NewTypingCache()
notifier := notifier.NewNotifier()
notifier := notifier.NewNotifier(rsAPI)
streams := streams.NewSyncStreamProviders(syncDB, userAPI, rsAPI, eduCache, caches, notifier)
notifier.SetCurrentPosition(streams.Latest(context.Background()))
if err = notifier.Load(context.Background(), syncDB); err != nil {

View file

@ -55,18 +55,27 @@ func ToClientEvents(serverEvs []gomatrixserverlib.PDU, format ClientEventFormat,
if err == nil && userID != nil {
sender = *userID
}
evs = append(evs, ToClientEvent(se, format, sender))
sk := se.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDForSender(se.RoomID(), spec.SenderID(*sk))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
evs = append(evs, ToClientEvent(se, format, sender, sk))
}
return evs
}
// ToClientEvent converts a single server event to a client event.
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID) ClientEvent {
func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender spec.UserID, stateKey *string) ClientEvent {
ce := ClientEvent{
Content: spec.RawJSON(se.Content()),
Sender: sender.String(),
Type: se.Type(),
StateKey: se.StateKey(),
StateKey: stateKey,
Unsigned: spec.RawJSON(se.Unsigned()),
OriginServerTS: se.OriginServerTS(),
EventID: se.EventID(),
@ -77,3 +86,23 @@ func ToClientEvent(se gomatrixserverlib.PDU, format ClientEventFormat, sender sp
}
return ce
}
// ToClientEvent converts a single server event to a client event.
// It provides default logic for event.SenderID & event.StateKey -> userID conversions.
func ToClientEventDefault(userIDQuery spec.UserIDForSender, event gomatrixserverlib.PDU) ClientEvent {
sender := spec.UserID{}
userID, err := userIDQuery(event.RoomID(), event.SenderID())
if err == nil && userID != nil {
sender = *userID
}
sk := event.StateKey()
if sk != nil && *sk != "" {
skUserID, err := userIDQuery(event.RoomID(), spec.SenderID(*event.StateKey()))
if err == nil && skUserID != nil {
skString := skUserID.String()
sk = &skString
}
}
return ToClientEvent(event, FormatAll, sender, sk)
}

View file

@ -48,7 +48,8 @@ func TestToClientEvent(t *testing.T) { // nolint: gocyclo
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatAll, *userID)
sk := ""
ce := ToClientEvent(ev, FormatAll, *userID, &sk)
if ce.EventID != ev.EventID() {
t.Errorf("ClientEvent.EventID: wanted %s, got %s", ev.EventID(), ce.EventID)
}
@ -107,7 +108,8 @@ func TestToClientFormatSync(t *testing.T) {
if err != nil {
t.Fatalf("failed to create userID: %s", err)
}
ce := ToClientEvent(ev, FormatSync, *userID)
sk := ""
ce := ToClientEvent(ev, FormatSync, *userID, &sk)
if ce.RoomID != "" {
t.Errorf("ClientEvent.RoomID: wanted '', got %s", ce.RoomID)
}

View file

@ -539,7 +539,7 @@ type InviteResponse struct {
}
// NewInviteResponse creates an empty response with initialised arrays.
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteResponse {
func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID, stateKey *string) *InviteResponse {
res := InviteResponse{}
res.InviteState.Events = []json.RawMessage{}
@ -552,7 +552,7 @@ func NewInviteResponse(event *types.HeaderedEvent, userID spec.UserID) *InviteRe
// Then we'll see if we can create a partial of the invite event itself.
// This is needed for clients to work out *who* sent the invite.
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID)
inviteEvent := synctypes.ToClientEvent(event.PDU, synctypes.FormatSync, userID, stateKey)
inviteEvent.Unsigned = nil
if ev, err := json.Marshal(inviteEvent); err == nil {
res.InviteState.Events = append(res.InviteState.Events, ev)

View file

@ -65,8 +65,14 @@ func TestNewInviteResponse(t *testing.T) {
if err != nil {
t.Fatal(err)
}
skUserID, err := spec.NewUserID("@neilalexander:dendrite.neilalexander.dev", true)
if err != nil {
t.Fatal(err)
}
skString := skUserID.String()
sk := &skString
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender)
res := NewInviteResponse(&types.HeaderedEvent{PDU: ev}, *sender, sk)
j, err := json.Marshal(res)
if err != nil {
t.Fatal(err)