Track knocking in membership updater (#1935)

* Topologically sort outliers in SendEventWithState

* Knock in membership updater

* Update gomatrixserverlib

* Update gomatrixserverlib

* Get the NID of the knock event properly for the membership updater
This commit is contained in:
Neil Alexander 2021-07-22 12:26:58 +01:00 committed by GitHub
parent 43ac66e0b4
commit 39e8d1cc6f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 47 additions and 3 deletions

View file

@ -136,6 +136,8 @@ func (r *Inputer) updateMembership(
return updateToJoinMembership(mu, add, updates)
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
return updateToLeaveMembership(mu, add, newMembership, updates)
case gomatrixserverlib.Knock:
return updateToKnockMembership(mu, add, updates)
default:
panic(fmt.Errorf(
"input: membership %q is not one of the allowed values", newMembership,
@ -220,6 +222,18 @@ func updateToLeaveMembership(
return updates, nil
}
func updateToKnockMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
if mu.IsLeave() {
_, err := mu.SetToKnock(add)
if err != nil {
return nil, err
}
}
return updates, nil
}
// membershipChanges pairs up the membership state changes.
func membershipChanges(removed, added []types.StateEntry) []stateChange {
changes := pairUpChanges(removed, added)

View file

@ -86,6 +86,11 @@ func (u *MembershipUpdater) IsLeave() bool {
return u.membership == tables.MembershipStateLeaveOrBan
}
// IsKnock implements types.MembershipUpdater
func (u *MembershipUpdater) IsKnock() bool {
return u.membership == tables.MembershipStateKnock
}
// SetToInvite implements types.MembershipUpdater
func (u *MembershipUpdater) SetToInvite(event gomatrixserverlib.Event) (bool, error) {
var inserted bool
@ -180,3 +185,27 @@ func (u *MembershipUpdater) SetToLeave(senderUserID string, eventID string) ([]s
})
return inviteEventIDs, err
}
// SetToKnock implements types.MembershipUpdater
func (u *MembershipUpdater) SetToKnock(event *gomatrixserverlib.Event) (bool, error) {
var inserted bool
err := u.d.Writer.Do(u.d.DB, u.txn, func(txn *sql.Tx) error {
senderUserNID, err := u.d.assignStateKeyNID(u.ctx, u.txn, event.Sender())
if err != nil {
return fmt.Errorf("u.d.AssignStateKeyNID: %w", err)
}
if u.membership != tables.MembershipStateKnock {
// Look up the NID of the new knock event
nIDs, err := u.d.EventNIDs(u.ctx, []string{event.EventID()})
if err != nil {
return fmt.Errorf("u.d.EventNIDs: %w", err)
}
if err = u.d.MembershipTable.UpdateMembership(u.ctx, u.txn, u.roomNID, u.targetUserNID, senderUserNID, tables.MembershipStateKnock, nIDs[event.EventID()], false); err != nil {
return fmt.Errorf("u.d.MembershipTable.UpdateMembership: %w", err)
}
}
return nil
})
return inserted, err
}

View file

@ -120,6 +120,7 @@ const (
MembershipStateLeaveOrBan MembershipState = 1
MembershipStateInvite MembershipState = 2
MembershipStateJoin MembershipState = 3
MembershipStateKnock MembershipState = 4
)
type Membership interface {