Membership updater refactoring (#2541)

* Membership updater refactoring

* Pass in membership state

* Use membership check rather than referring to state directly

* Delete irrelevant membership states

* We don't need the leave event after all

* Tweaks

* Put a log entry in that I might stand a chance of finding

* Be less panicky

* Tweak invite handling

* Don't freak if we can't find the event NID

* Use event NID from `types.Event`

* Clean up

* Better invite handling

* Placate the almighty linter

* Blacklist a Sytest which is otherwise fine under Complement for reasons I don't understand

* Fix the sytest after all (thanks @S7evinK for the spot)
This commit is contained in:
Neil Alexander 2022-07-22 14:44:04 +01:00 committed by GitHub
parent a201b4400d
commit f0c8a03649
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 181 additions and 245 deletions

View file

@ -26,7 +26,6 @@ import (
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID} // InviteV2 implements /_matrix/federation/v2/invite/{roomID}/{eventID}
@ -144,7 +143,6 @@ func processInvite(
// Check that the event is signed by the server sending the request. // Check that the event is signed by the server sending the request.
redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version()) redacted, err := gomatrixserverlib.RedactEventJSON(event.JSON(), event.Version())
if err != nil { if err != nil {
logrus.WithError(err).Errorf("XXX: invite.go")
return util.JSONResponse{ return util.JSONResponse{
Code: http.StatusBadRequest, Code: http.StatusBadRequest,
JSON: jsonerror.BadJSON("The event JSON could not be redacted"), JSON: jsonerror.BadJSON("The event JSON could not be redacted"),

View file

@ -12,6 +12,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/state" "github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
@ -21,14 +22,14 @@ import (
// Move these to a more sensible place. // Move these to a more sensible place.
func UpdateToInviteMembership( func UpdateToInviteMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion, roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are // We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a // reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need // remote server via the federation invite API. In those cases we don't need
// to send the event. // to send the event.
needsSending, err := mu.SetToInvite(add) needsSending, retired, err := mu.Update(tables.MembershipStateInvite, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -38,13 +39,23 @@ func UpdateToInviteMembership(
// room event stream. This ensures that the consumers only have to // room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user // consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves. // is invited, rather than having to combine multiple streams themselves.
onie := api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
}
updates = append(updates, api.OutputEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent, Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie, NewInviteEvent: &api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
},
})
}
for _, eventID := range retired {
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
},
}) })
} }
return updates, nil return updates, nil

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api" "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/internal/helpers" "github.com/matrix-org/dendrite/roomserver/internal/helpers"
"github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go" "github.com/opentracing/opentracing-go"
@ -60,20 +61,14 @@ func (r *Inputer) updateMemberships(
var updates []api.OutputEvent var updates []api.OutputEvent
for _, change := range changes { for _, change := range changes {
var ae *gomatrixserverlib.Event var ae *types.Event
var re *gomatrixserverlib.Event var re *types.Event
targetUserNID := change.EventStateKeyNID targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 { if change.removedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.removedEventNID) re, _ = helpers.EventMap(events).Lookup(change.removedEventNID)
if ev != nil {
re = ev.Event
}
} }
if change.addedEventNID != 0 { if change.addedEventNID != 0 {
ev, _ := helpers.EventMap(events).Lookup(change.addedEventNID) ae, _ = helpers.EventMap(events).Lookup(change.addedEventNID)
if ev != nil {
ae = ev.Event
}
} }
if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil { if updates, err = r.updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err return nil, err
@ -85,30 +80,27 @@ func (r *Inputer) updateMemberships(
func (r *Inputer) updateMembership( func (r *Inputer) updateMembership(
updater *shared.RoomUpdater, updater *shared.RoomUpdater,
targetUserNID types.EventStateKeyNID, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event, remove, add *types.Event,
updates []api.OutputEvent, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var err error var err error
// Default the membership to Leave if no event was added or removed. // Default the membership to Leave if no event was added or removed.
oldMembership := gomatrixserverlib.Leave
newMembership := gomatrixserverlib.Leave newMembership := gomatrixserverlib.Leave
if remove != nil {
oldMembership, err = remove.Membership()
if err != nil {
return nil, err
}
}
if add != nil { if add != nil {
newMembership, err = add.Membership() newMembership, err = add.Membership()
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return var targetLocal bool
// immediately, unless it's a Join update (e.g. profile update). if add != nil {
return updates, nil targetLocal = r.isLocalTarget(add)
}
mu, err := updater.MembershipUpdater(targetUserNID, targetLocal)
if err != nil {
return nil, err
} }
// In an ideal world, we shouldn't ever have "add" be nil and "remove" be // In an ideal world, we shouldn't ever have "add" be nil and "remove" be
@ -120,17 +112,10 @@ func (r *Inputer) updateMembership(
// after a state reset, often thinking that the user was still joined to // after a state reset, often thinking that the user was still joined to
// the room even though the room state said otherwise, and this would prevent // the room even though the room state said otherwise, and this would prevent
// the user from being able to attempt to rejoin the room without modifying // the user from being able to attempt to rejoin the room without modifying
// the database. So instead what we'll do is we'll just update the membership // the database. So instead we're going to remove the membership from the
// table to say that the user is "leave" and we'll use the old event to // database altogether, so that it doesn't create future problems.
// avoid nil pointer exceptions on the code path that follows. if add == nil && remove != nil {
if add == nil { return nil, mu.Delete()
add = remove
newMembership = gomatrixserverlib.Leave
}
mu, err := updater.MembershipUpdater(targetUserNID, r.isLocalTarget(add))
if err != nil {
return nil, err
} }
switch newMembership { switch newMembership {
@ -149,7 +134,7 @@ func (r *Inputer) updateMembership(
} }
} }
func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool { func (r *Inputer) isLocalTarget(event *types.Event) bool {
isTargetLocalUser := false isTargetLocalUser := false
if statekey := event.StateKey(); statekey != nil { if statekey := event.StateKey(); statekey != nil {
_, domain, _ := gomatrixserverlib.SplitID('@', *statekey) _, domain, _ := gomatrixserverlib.SplitID('@', *statekey)
@ -159,82 +144,62 @@ func (r *Inputer) isLocalTarget(event *gomatrixserverlib.Event) bool {
} }
func updateToJoinMembership( func updateToJoinMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If the user is already marked as being joined, we call SetToJoin to update
// the event ID then we can return immediately. Retired is ignored as there
// is no invite event to retire.
if mu.IsJoin() {
_, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
if err != nil {
return nil, err
}
return updates, nil
}
// When we mark a user as being joined we will invalidate any invites that // When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false) _, retired, err := mu.Update(tables.MembershipStateJoin, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: gomatrixserverlib.Join, Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetUserID: *add.StateKey(),
} },
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
}) })
} }
return updates, nil return updates, nil
} }
func updateToLeaveMembership( func updateToLeaveMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, mu *shared.MembershipUpdater, add *types.Event,
newMembership string, updates []api.OutputEvent, newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
// If the user is already neither joined, nor invited to the room then we
// can return immediately.
if mu.IsLeave() {
return updates, nil
}
// When we mark a user as having left we will invalidate any invites that // When we mark a user as having left we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have // are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this // been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream. // by studying the state changes in the room event stream.
retired, err := mu.SetToLeave(add.Sender(), add.EventID()) _, retired, err := mu.Update(tables.MembershipStateLeaveOrBan, add)
if err != nil { if err != nil {
return nil, err return nil, err
} }
for _, eventID := range retired { for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{ updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &api.OutputRetireInviteEvent{
EventID: eventID, EventID: eventID,
Membership: newMembership, Membership: newMembership,
RetiredByEventID: add.EventID(), RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(), TargetUserID: *add.StateKey(),
} },
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
}) })
} }
return updates, nil return updates, nil
} }
func updateToKnockMembership( func updateToKnockMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent, mu *shared.MembershipUpdater, add *types.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
if mu.IsLeave() { if _, _, err := mu.Update(tables.MembershipStateKnock, add); err != nil {
_, err := mu.SetToKnock(add)
if err != nil {
return nil, err return nil, err
} }
}
return updates, nil return updates, nil
} }

View file

@ -39,11 +39,13 @@ type Inviter struct {
Inputer *input.Inputer Inputer *input.Inputer
} }
// nolint:gocyclo
func (r *Inviter) PerformInvite( func (r *Inviter) PerformInvite(
ctx context.Context, ctx context.Context,
req *api.PerformInviteRequest, req *api.PerformInviteRequest,
res *api.PerformInviteResponse, res *api.PerformInviteResponse,
) ([]api.OutputEvent, error) { ) ([]api.OutputEvent, error) {
var outputUpdates []api.OutputEvent
event := req.Event event := req.Event
if event.StateKey() == nil { if event.StateKey() == nil {
return nil, fmt.Errorf("invite must be a state event") return nil, fmt.Errorf("invite must be a state event")
@ -66,6 +68,13 @@ func (r *Inviter) PerformInvite(
} }
isTargetLocal := domain == r.Cfg.Matrix.ServerName isTargetLocal := domain == r.Cfg.Matrix.ServerName
isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName isOriginLocal := event.Origin() == r.Cfg.Matrix.ServerName
if !isOriginLocal && !isTargetLocal {
res.Error = &api.PerformError{
Code: api.PerformErrorBadRequest,
Msg: "The invite must be either from or to a local user",
}
return nil, nil
}
logger := util.GetLogger(ctx).WithFields(map[string]interface{}{ logger := util.GetLogger(ctx).WithFields(map[string]interface{}{
"inviter": event.Sender(), "inviter": event.Sender(),
@ -97,6 +106,34 @@ func (r *Inviter) PerformInvite(
} }
} }
updateMembershipTableManually := func() ([]api.OutputEvent, error) {
var updater *shared.MembershipUpdater
if updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion); err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
outputUpdates, err = helpers.UpdateToInviteMembership(updater, &types.Event{
EventNID: 0,
Event: event.Unwrap(),
}, outputUpdates, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
}
if (info == nil || info.IsStub) && !isOriginLocal && isTargetLocal {
// The invite came in over federation for a room that we don't know about
// yet. We need to handle this a bit differently to most invites because
// we don't know the room state, therefore the roomserver can't process
// an input event. Instead we will update the membership table with the
// new invite and generate an output event.
return updateMembershipTableManually()
}
var isAlreadyJoined bool var isAlreadyJoined bool
if info != nil { if info != nil {
_, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey()) _, isAlreadyJoined, _, err = r.DB.GetMembership(ctx, info.RoomNID, *event.StateKey())
@ -140,31 +177,13 @@ func (r *Inviter) PerformInvite(
return nil, nil return nil, nil
} }
// If the invite originated remotely then we can't send an
// InputRoomEvent for the invite as it will never pass auth checks
// due to lacking room state, but we still need to tell the client
// about the invite so we can accept it, hence we return an output
// event to send to the Sync API.
if !isOriginLocal { if !isOriginLocal {
// The invite originated over federation. Process the membership return updateMembershipTableManually()
// update, which will notify the sync API etc about the incoming
// invite. We do NOT send an InputRoomEvent for the invite as it
// will never pass auth checks due to lacking room state, but we
// still need to tell the client about the invite so we can accept
// it, hence we return an output event to send to the sync api.
var updater *shared.MembershipUpdater
updater, err = r.DB.MembershipUpdater(ctx, roomID, targetUserID, isTargetLocal, req.RoomVersion)
if err != nil {
return nil, fmt.Errorf("r.DB.MembershipUpdater: %w", err)
}
unwrapped := event.Unwrap()
var outputUpdates []api.OutputEvent
outputUpdates, err = helpers.UpdateToInviteMembership(updater, unwrapped, nil, req.Event.RoomVersion)
if err != nil {
return nil, fmt.Errorf("updateToInviteMembership: %w", err)
}
if err = updater.Commit(); err != nil {
return nil, fmt.Errorf("updater.Commit: %w", err)
}
logger.Debugf("updated membership to invite and sending invite OutputEvent")
return outputUpdates, nil
} }
// The invite originated locally. Therefore we have a responsibility to // The invite originated locally. Therefore we have a responsibility to
@ -229,12 +248,11 @@ func (r *Inviter) PerformInvite(
Code: api.PerformErrorNotAllowed, Code: api.PerformErrorNotAllowed,
} }
logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed") logger.WithError(err).WithField("event_id", event.EventID()).Error("r.InputRoomEvents failed")
return nil, nil
} }
// Don't notify the sync api of this event in the same way as a federated invite so the invitee // Don't notify the sync api of this event in the same way as a federated invite so the invitee
// gets the invite, as the roomserver will do this when it processes the m.room.member invite. // gets the invite, as the roomserver will do this when it processes the m.room.member invite.
return nil, nil return outputUpdates, nil
} }
func buildInviteStrippedState( func buildInviteStrippedState(

View file

@ -268,21 +268,19 @@ func (r *Joiner) performJoinRoomByID(
case nil: case nil:
// The room join is local. Send the new join event into the // The room join is local. Send the new join event into the
// roomserver. First of all check that the user isn't already // roomserver. First of all check that the user isn't already
// a member of the room. // a member of the room. This is best-effort (as in we won't
alreadyJoined := false // fail if we can't find the existing membership) because there
for _, se := range buildRes.StateEvents { // is really no harm in just sending another membership event.
if !se.StateKeyEquals(userID) { membershipReq := &api.QueryMembershipForUserRequest{
continue RoomID: req.RoomIDOrAlias,
} UserID: userID,
if membership, merr := se.Membership(); merr == nil {
alreadyJoined = (membership == gomatrixserverlib.Join)
break
}
} }
membershipRes := &api.QueryMembershipForUserResponse{}
_ = r.Queryer.QueryMembershipForUser(ctx, membershipReq, membershipRes)
// If we haven't already joined the room then send an event // If we haven't already joined the room then send an event
// into the room changing our membership status. // into the room changing our membership status.
if !alreadyJoined { if !membershipRes.RoomExists || !membershipRes.IsInRoom {
inputReq := rsAPI.InputRoomEventsRequest{ inputReq := rsAPI.InputRoomEventsRequest{
InputRoomEvents: []rsAPI.InputRoomEvent{ InputRoomEvents: []rsAPI.InputRoomEvent{
{ {

View file

@ -228,14 +228,14 @@ func (r *Leaver) performFederatedRejectInvite(
util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to get MembershipUpdater, still retiring invite event")
} }
if updater != nil { if updater != nil {
if _, err = updater.SetToLeave(req.UserID, eventID); err != nil { if err = updater.Delete(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to set membership to leave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to delete membership, still retiring invite event")
if err = updater.Rollback(); err != nil { if err = updater.Rollback(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to rollback membership leave, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to rollback deleting membership, still retiring invite event")
} }
} else { } else {
if err = updater.Commit(); err != nil { if err = updater.Commit(); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("failed to commit membership update, still retiring invite event") util.GetLogger(ctx).WithError(err).Errorf("failed to commit deleting membership, still retiring invite event")
} }
} }
} }

View file

@ -118,6 +118,9 @@ const updateMembershipForgetRoom = "" +
"UPDATE roomserver_membership SET forgotten = $3" + "UPDATE roomserver_membership SET forgotten = $3" +
" WHERE room_nid = $1 AND target_nid = $2" " WHERE room_nid = $1 AND target_nid = $2"
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
const selectRoomsWithMembershipSQL = "" + const selectRoomsWithMembershipSQL = "" +
"SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false" "SELECT room_nid FROM roomserver_membership WHERE membership_nid = $1 AND target_nid = $2 and forgotten = false"
@ -165,6 +168,7 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
@ -191,6 +195,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -412,3 +417,13 @@ func (s *membershipStatements) SelectServerInRoom(
} }
return roomNID == nid, nil return roomNID == nid, nil
} }
func (s *membershipStatements) DeleteMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID,
)
return err
}

View file

@ -15,7 +15,7 @@ type MembershipUpdater struct {
d *Database d *Database
roomNID types.RoomNID roomNID types.RoomNID
targetUserNID types.EventStateKeyNID targetUserNID types.EventStateKeyNID
membership tables.MembershipState oldMembership tables.MembershipState
} }
func NewMembershipUpdater( func NewMembershipUpdater(
@ -30,7 +30,6 @@ func NewMembershipUpdater(
if err != nil { if err != nil {
return err return err
} }
targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID) targetUserNID, err = d.assignStateKeyNID(ctx, targetUserID)
if err != nil { if err != nil {
return err return err
@ -73,146 +72,62 @@ func (d *Database) membershipUpdaterTxn(
// IsInvite implements types.MembershipUpdater // IsInvite implements types.MembershipUpdater
func (u *MembershipUpdater) IsInvite() bool { func (u *MembershipUpdater) IsInvite() bool {
return u.membership == tables.MembershipStateInvite return u.oldMembership == tables.MembershipStateInvite
} }
// IsJoin implements types.MembershipUpdater // IsJoin implements types.MembershipUpdater
func (u *MembershipUpdater) IsJoin() bool { func (u *MembershipUpdater) IsJoin() bool {
return u.membership == tables.MembershipStateJoin return u.oldMembership == tables.MembershipStateJoin
} }
// IsLeave implements types.MembershipUpdater // IsLeave implements types.MembershipUpdater
func (u *MembershipUpdater) IsLeave() bool { func (u *MembershipUpdater) IsLeave() bool {
return u.membership == tables.MembershipStateLeaveOrBan return u.oldMembership == tables.MembershipStateLeaveOrBan
} }
// IsKnock implements types.MembershipUpdater // IsKnock implements types.MembershipUpdater
func (u *MembershipUpdater) IsKnock() bool { func (u *MembershipUpdater) IsKnock() bool {
return u.membership == tables.MembershipStateKnock return u.oldMembership == tables.MembershipStateKnock
} }
// SetToInvite implements types.MembershipUpdater func (u *MembershipUpdater) Delete() error {
func (u *MembershipUpdater) SetToInvite(event *gomatrixserverlib.Event) (bool, error) { if _, err := u.d.InvitesTable.UpdateInviteRetired(u.ctx, u.txn, u.roomNID, u.targetUserNID); err != nil {
var inserted bool return err
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error { }
return u.d.MembershipTable.DeleteMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID)
}
func (u *MembershipUpdater) Update(newMembership tables.MembershipState, event *types.Event) (bool, []string, error) {
var inserted bool // Did the query result in a membership change?
var retired []string // Did we retire any updates in the process?
return inserted, retired, u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender()) senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
if err != nil { if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err) return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
} }
inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, newMembership, event.EventNID, false)
if err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
if !inserted {
return nil
}
switch {
case u.oldMembership != tables.MembershipStateInvite && newMembership == tables.MembershipStateInvite:
inserted, err = u.d.InvitesTable.InsertInviteEvent( inserted, err = u.d.InvitesTable.InsertInviteEvent(
u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(), u.ctx, u.txn, event.EventID(), u.roomNID, u.targetUserNID, senderUserNID, event.JSON(),
) )
if err != nil { if err != nil {
return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err) return fmt.Errorf("u.d.InvitesTable.InsertInviteEvent: %w", err)
} }
case u.oldMembership == tables.MembershipStateInvite && newMembership != tables.MembershipStateInvite:
// Look up the NID of the invite event retired, err = u.d.InvitesTable.UpdateInviteRetired(
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateInvite {
if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateInvite, nIDs[event.EventID()], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inserted, err
}
// SetToJoin implements types.MembershipUpdater
func (u *MembershipUpdater) SetToJoin(senderUserID string, eventID string, isUpdate bool) ([]string, error) {
var inviteEventIDs []string
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
// If this is a join event update, there is no invite to update
if !isUpdate {
inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID, u.ctx, u.txn, u.roomNID, u.targetUserNID,
) )
if err != nil { if err != nil {
return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err) return fmt.Errorf("u.d.InvitesTables.UpdateInviteRetired: %w", err)
} }
} }
// Look up the NID of the new join event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateJoin || isUpdate {
if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateJoin, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil return nil
}) })
return inviteEventIDs, err
}
// SetToLeave implements types.MembershipUpdater
func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]string, error) {
var inviteEventIDs []string
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, senderUserID)
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
inviteEventIDs, err = u.d.InvitesTable.UpdateInviteRetired(
u.ctx, u.txn, u.roomNID, u.targetUserNID,
)
if err != nil {
return fmt.Errorf("u.d.InvitesTable.updateInviteRetired: %w", err)
}
// Look up the NID of the new leave event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{eventID}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if u.membership != tables.MembershipStateLeaveOrBan {
if _, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateLeaveOrBan, nIDs[eventID], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inviteEventIDs, err
}
// SetToKnock implements types.MembershipUpdater
func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) {
var inserted bool
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, event.Sender())
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
nIDs, err := u.d.eventNIDs(u.ctx, u.txn, []string{event.EventID()}, false)
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if inserted, err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inserted, err
} }

View file

@ -125,6 +125,9 @@ const selectServerInRoomSQL = "" +
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" + " JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1" " WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
const deleteMembershipSQL = "" +
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
type membershipStatements struct { type membershipStatements struct {
db *sql.DB db *sql.DB
insertMembershipStmt *sql.Stmt insertMembershipStmt *sql.Stmt
@ -140,6 +143,7 @@ type membershipStatements struct {
updateMembershipForgetRoomStmt *sql.Stmt updateMembershipForgetRoomStmt *sql.Stmt
selectLocalServerInRoomStmt *sql.Stmt selectLocalServerInRoomStmt *sql.Stmt
selectServerInRoomStmt *sql.Stmt selectServerInRoomStmt *sql.Stmt
deleteMembershipStmt *sql.Stmt
} }
func CreateMembershipTable(db *sql.DB) error { func CreateMembershipTable(db *sql.DB) error {
@ -166,6 +170,7 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
{&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom}, {&s.updateMembershipForgetRoomStmt, updateMembershipForgetRoom},
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL}, {&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
{&s.selectServerInRoomStmt, selectServerInRoomSQL}, {&s.selectServerInRoomStmt, selectServerInRoomSQL},
{&s.deleteMembershipStmt, deleteMembershipSQL},
}.Prepare(db) }.Prepare(db)
} }
@ -383,3 +388,13 @@ func (s *membershipStatements) SelectServerInRoom(ctx context.Context, txn *sql.
} }
return roomNID == nid, nil return roomNID == nid, nil
} }
func (s *membershipStatements) DeleteMembership(
ctx context.Context, txn *sql.Tx,
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
) error {
_, err := sqlutil.TxStmt(txn, s.deleteMembershipStmt).ExecContext(
ctx, roomNID, targetUserNID,
)
return err
}

View file

@ -133,6 +133,7 @@ type Membership interface {
UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error UpdateForgetMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID, forget bool) error
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error) SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error) SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
} }
type Published interface { type Published interface {

View file

@ -365,7 +365,7 @@ func (s *OutputRoomEventConsumer) onNewInviteEvent(
"event": string(msg.Event.JSON()), "event": string(msg.Event.JSON()),
"pdupos": pduPos, "pdupos": pduPos,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write invite failure") }).Errorf("roomserver output log: write invite failure")
return return
} }
@ -385,7 +385,7 @@ func (s *OutputRoomEventConsumer) onRetireInviteEvent(
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"event_id": msg.EventID, "event_id": msg.EventID,
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: remove invite failure") }).Errorf("roomserver output log: remove invite failure")
return return
} }
@ -403,7 +403,7 @@ func (s *OutputRoomEventConsumer) onNewPeek(
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write peek failure") }).Errorf("roomserver output log: write peek failure")
return return
} }
@ -422,7 +422,7 @@ func (s *OutputRoomEventConsumer) onRetirePeek(
// panic rather than continue with an inconsistent database // panic rather than continue with an inconsistent database
log.WithFields(log.Fields{ log.WithFields(log.Fields{
log.ErrorKey: err, log.ErrorKey: err,
}).Panicf("roomserver output log: write peek failure") }).Errorf("roomserver output log: write peek failure")
return return
} }