From 6bf1912525724baa1a8bf5e0bf51524d7cee59ba Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Mon, 10 Oct 2022 16:54:04 +0100 Subject: [PATCH] Fix joined hosts with `RewritesState` (#2785) This ensures that the joined hosts in the federation API are correct after the state is rewritten. This might fix some races around the time of joining federated rooms. --- federationapi/storage/shared/storage.go | 32 ++++++++++++------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/federationapi/storage/shared/storage.go b/federationapi/storage/shared/storage.go index a00d782f..9e40f311 100644 --- a/federationapi/storage/shared/storage.go +++ b/federationapi/storage/shared/storage.go @@ -70,27 +70,27 @@ func (d *Database) UpdateRoom( ) (joinedHosts []types.JoinedHost, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { if purgeRoomFirst { - // If the event is a create event then we'll delete all of the existing - // data for the room. The only reason that a create event would be replayed - // to us in this way is if we're about to receive the entire room state. if err = d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil { return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err) } - } - - joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID) - if err != nil { - return err - } - - for _, add := range addHosts { - err = d.FederationJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName) - if err != nil { + for _, add := range addHosts { + if err = d.FederationJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName); err != nil { + return err + } + joinedHosts = append(joinedHosts, add) + } + } else { + if joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID); err != nil { + return err + } + for _, add := range addHosts { + if err = d.FederationJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName); err != nil { + return err + } + } + if err = d.FederationJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil { return err } - } - if err = d.FederationJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil { - return err } return nil })