Use memberships to determine whether to reset latest events/state on room join (#1047)

* Track local/remote memberships, re-scope some input stuff

* Check if we're in the room already before resetting latest events/state

* Fix postgres, fix lint

* Review comments
This commit is contained in:
Neil Alexander 2020-05-20 18:03:06 +01:00 committed by GitHub
parent 6091bf044f
commit f2c07437fe
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 162 additions and 100 deletions

View file

@ -569,9 +569,9 @@ func (u *roomRecentEventsUpdater) MarkEventAsSent(eventNID types.EventNID) error
return err
}
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID) (mu types.MembershipUpdater, err error) {
func (u *roomRecentEventsUpdater) MembershipUpdater(targetUserNID types.EventStateKeyNID, targetLocal bool) (mu types.MembershipUpdater, err error) {
err = common.WithTransaction(u.d.db, func(txn *sql.Tx) error {
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID)
mu, err = u.d.membershipUpdaterTxn(u.ctx, txn, u.roomNID, targetUserNID, targetLocal)
return err
})
return
@ -680,7 +680,7 @@ func (d *Database) StateEntriesForTuples(
// MembershipUpdater implements input.RoomEventDatabase
func (d *Database) MembershipUpdater(
ctx context.Context, roomID, targetUserID string,
roomVersion gomatrixserverlib.RoomVersion,
targetLocal bool, roomVersion gomatrixserverlib.RoomVersion,
) (updater types.MembershipUpdater, err error) {
var txn *sql.Tx
txn, err = d.db.Begin()
@ -716,7 +716,7 @@ func (d *Database) MembershipUpdater(
return nil, err
}
updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID)
updater, err = d.membershipUpdaterTxn(ctx, txn, roomNID, targetUserNID, targetLocal)
if err != nil {
return nil, err
}
@ -738,9 +738,10 @@ func (d *Database) membershipUpdaterTxn(
txn *sql.Tx,
roomNID types.RoomNID,
targetUserNID types.EventStateKeyNID,
targetLocal bool,
) (types.MembershipUpdater, error) {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID); err != nil {
if err := d.statements.insertMembership(ctx, txn, roomNID, targetUserNID, targetLocal); err != nil {
return nil, err
}
@ -896,17 +897,17 @@ func (d *Database) GetMembership(
// GetMembershipEventNIDsForRoom implements query.RoomserverQueryAPIDB
func (d *Database) GetMembershipEventNIDsForRoom(
ctx context.Context, roomNID types.RoomNID, joinOnly bool,
ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool,
) (eventNIDs []types.EventNID, err error) {
err = common.WithTransaction(d.db, func(txn *sql.Tx) error {
if joinOnly {
eventNIDs, err = d.statements.selectMembershipsFromRoomAndMembership(
ctx, txn, roomNID, membershipStateJoin,
ctx, txn, roomNID, membershipStateJoin, localOnly,
)
return nil
}
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID)
eventNIDs, err = d.statements.selectMembershipsFromRoom(ctx, txn, roomNID, localOnly)
return nil
})
return