Refactor forward extremities (#1556)

* Add resolve-state helper

* Tweaks

* Refactor forward extremities, again

* Tweaks

* Minor optimisation

* Make path a bit clearer

* Only process state/membership if forward extremities have changed

* Usage comments in resolve-state
This commit is contained in:
Neil Alexander 2020-10-21 15:37:07 +01:00 committed by GitHub
parent e4f3f38f35
commit 534f9a9eb6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 209 additions and 58 deletions

132
cmd/resolve-state/main.go Normal file
View file

@ -0,0 +1,132 @@
package main
import (
"context"
"flag"
"fmt"
"os"
"strconv"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/setup"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// This is a utility for inspecting state snapshots and running state resolution
// against real snapshots in an actual database.
// It takes one or more state snapshot NIDs as arguments, along with a room version
// to use for unmarshalling events, and will produce resolved output.
//
// Usage: ./resolve-state --roomversion=version snapshot [snapshot ...]
// e.g. ./resolve-state --roomversion=5 1254 1235 1282
var roomVersion = flag.String("roomversion", "5", "the room version to parse events as")
// nolint:gocyclo
func main() {
ctx := context.Background()
cfg := setup.ParseFlags(true)
args := os.Args[1:]
fmt.Println("Room version", *roomVersion)
snapshotNIDs := []types.StateSnapshotNID{}
for _, arg := range args {
if i, err := strconv.Atoi(arg); err == nil {
snapshotNIDs = append(snapshotNIDs, types.StateSnapshotNID(i))
}
}
fmt.Println("Fetching", len(snapshotNIDs), "snapshot NIDs")
cache, err := caching.NewInMemoryLRUCache(true)
if err != nil {
panic(err)
}
roomserverDB, err := storage.Open(&cfg.RoomServer.Database, cache)
if err != nil {
panic(err)
}
blockNIDs, err := roomserverDB.StateBlockNIDs(ctx, snapshotNIDs)
if err != nil {
panic(err)
}
var stateEntries []types.StateEntryList
for _, list := range blockNIDs {
entries, err2 := roomserverDB.StateEntries(ctx, list.StateBlockNIDs)
if err2 != nil {
panic(err2)
}
stateEntries = append(stateEntries, entries...)
}
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
for _, e := range entry.StateEntries {
eventNIDs = append(eventNIDs, e.EventNID)
}
}
fmt.Println("Fetching", len(eventNIDs), "state events")
eventEntries, err := roomserverDB.Events(ctx, eventNIDs)
if err != nil {
panic(err)
}
authEventIDMap := make(map[string]struct{})
eventPtrs := make([]*gomatrixserverlib.Event, len(eventEntries))
for i := range eventEntries {
eventPtrs[i] = &eventEntries[i].Event
for _, authEventID := range eventEntries[i].AuthEventIDs() {
authEventIDMap[authEventID] = struct{}{}
}
}
authEventIDs := make([]string, 0, len(authEventIDMap))
for authEventID := range authEventIDMap {
authEventIDs = append(authEventIDs, authEventID)
}
fmt.Println("Fetching", len(authEventIDs), "auth events")
authEventEntries, err := roomserverDB.EventsFromIDs(ctx, authEventIDs)
if err != nil {
panic(err)
}
authEventPtrs := make([]*gomatrixserverlib.Event, len(authEventEntries))
for i := range authEventEntries {
authEventPtrs[i] = &authEventEntries[i].Event
}
events := make([]gomatrixserverlib.Event, len(eventEntries))
authEvents := make([]gomatrixserverlib.Event, len(authEventEntries))
for i, ptr := range eventPtrs {
events[i] = *ptr
}
for i, ptr := range authEventPtrs {
authEvents[i] = *ptr
}
fmt.Println("Resolving state")
resolved, err := state.ResolveConflictsAdhoc(
gomatrixserverlib.RoomVersion(*roomVersion),
events,
authEvents,
)
if err != nil {
panic(err)
}
fmt.Println("Resolved state contains", len(resolved), "events")
for _, event := range resolved {
fmt.Println()
fmt.Printf("* %s %s %q\n", event.EventID(), event.Type(), *event.StateKey())
fmt.Printf(" %s\n", string(event.Content()))
}
}

View file

@ -17,7 +17,6 @@
package input package input
import ( import (
"bytes"
"context" "context"
"fmt" "fmt"
@ -28,7 +27,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/sirupsen/logrus"
) )
// updateLatestEvents updates the list of latest events for this room in the database and writes the // updateLatestEvents updates the list of latest events for this room in the database and writes the
@ -141,28 +139,31 @@ func (u *latestEventsUpdater) doUpdateLatestEvents() error {
// Work out what the latest events are. This will include the new // Work out what the latest events are. This will include the new
// event if it is not already referenced. // event if it is not already referenced.
if err := u.calculateLatest( extremitiesChanged, err := u.calculateLatest(
oldLatest, oldLatest, &u.event,
types.StateAtEventAndReference{ types.StateAtEventAndReference{
EventReference: u.event.EventReference(), EventReference: u.event.EventReference(),
StateAtEvent: u.stateAtEvent, StateAtEvent: u.stateAtEvent,
}, },
); err != nil { )
if err != nil {
return fmt.Errorf("u.calculateLatest: %w", err) return fmt.Errorf("u.calculateLatest: %w", err)
} }
// Now that we know what the latest events are, it's time to get the // Now that we know what the latest events are, it's time to get the
// latest state. // latest state.
if err := u.latestState(); err != nil { var updates []api.OutputEvent
if extremitiesChanged {
if err = u.latestState(); err != nil {
return fmt.Errorf("u.latestState: %w", err) return fmt.Errorf("u.latestState: %w", err)
} }
// If we need to generate any output events then here's where we do it. // If we need to generate any output events then here's where we do it.
// TODO: Move this! // TODO: Move this!
updates, err := u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added) if updates, err = u.api.updateMemberships(u.ctx, u.updater, u.removed, u.added); err != nil {
if err != nil {
return fmt.Errorf("u.api.updateMemberships: %w", err) return fmt.Errorf("u.api.updateMemberships: %w", err)
} }
}
update, err := u.makeOutputNewRoomEvent() update, err := u.makeOutputNewRoomEvent()
if err != nil { if err != nil {
@ -250,50 +251,74 @@ func (u *latestEventsUpdater) latestState() error {
// true if the new event is included in those extremites, false otherwise. // true if the new event is included in those extremites, false otherwise.
func (u *latestEventsUpdater) calculateLatest( func (u *latestEventsUpdater) calculateLatest(
oldLatest []types.StateAtEventAndReference, oldLatest []types.StateAtEventAndReference,
newEvent types.StateAtEventAndReference, newEvent *gomatrixserverlib.Event,
) error { newStateAndRef types.StateAtEventAndReference,
var newLatest []types.StateAtEventAndReference ) (bool, error) {
// First of all, get a list of all of the events in our current
// set of forward extremities.
existingRefs := make(map[string]*types.StateAtEventAndReference)
existingNIDs := make([]types.EventNID, len(oldLatest))
for i, old := range oldLatest {
existingRefs[old.EventID] = &oldLatest[i]
existingNIDs[i] = old.EventNID
}
// First of all, let's see if any of the existing forward extremities // Look up the old extremity events. This allows us to find their
// now have entries in the previous events table. If they do then we // prev events.
// will no longer include them as forward extremities. events, err := u.api.DB.Events(u.ctx, existingNIDs)
for _, l := range oldLatest {
referenced, err := u.updater.IsReferenced(l.EventReference)
if err != nil { if err != nil {
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", l.EventID) return false, fmt.Errorf("u.api.DB.Events: %w", err)
return fmt.Errorf("u.updater.IsReferenced (old): %w", err) }
} else if !referenced {
newLatest = append(newLatest, l) // Make a list of all of the prev events as referenced by all of
// the current forward extremities.
existingPrevs := make(map[string]struct{})
for _, old := range events {
for _, prevEventID := range old.PrevEventIDs() {
existingPrevs[prevEventID] = struct{}{}
} }
} }
// Then check and see if our new event is already included in that set. // If the "new" event is already referenced by a forward extremity
// This ordinarily won't happen but it covers the edge-case that we've // then do nothing - it's not a candidate to be a new extremity if
// already seen this event before and it's a forward extremity, so rather // it has been referenced.
// than adding a duplicate, we'll just return the set as complete. if _, ok := existingPrevs[newEvent.EventID()]; ok {
for _, l := range newLatest { return false, nil
if l.EventReference.EventID == newEvent.EventReference.EventID && bytes.Equal(l.EventReference.EventSHA256, newEvent.EventReference.EventSHA256) { }
// We've already referenced this new event so we can just return
// the newly completed extremities at this point. // If the "new" event is already a forward extremity then stop, as
u.latest = newLatest // nothing changes.
return nil for _, event := range events {
if event.EventID() == newEvent.EventID() {
return false, nil
} }
} }
// At this point we've processed the old extremities, and we've checked // Include our new event in the extremities.
// that our new event isn't already in that set. Therefore now we can newLatest := []types.StateAtEventAndReference{newStateAndRef}
// check if our *new* event is a forward extremity, and if it is, add
// it in. // Then run through and see if the other extremities are still valid.
referenced, err := u.updater.IsReferenced(newEvent.EventReference) // If our new event references them then they are no longer good
if err != nil { // candidates.
logrus.WithError(err).Errorf("Failed to retrieve event reference for %q", newEvent.EventReference.EventID) for _, prevEventID := range newEvent.PrevEventIDs() {
return fmt.Errorf("u.updater.IsReferenced (new): %w", err) delete(existingRefs, prevEventID)
} else if !referenced || len(newLatest) == 0 { }
newLatest = append(newLatest, newEvent)
// Ensure that we don't add any candidate forward extremities from
// the old set that are, themselves, referenced by the old set of
// forward extremities. This shouldn't happen but guards against
// the possibility anyway.
for prevEventID := range existingPrevs {
delete(existingRefs, prevEventID)
}
// Then re-add any old extremities that are still valid after all.
for _, old := range existingRefs {
newLatest = append(newLatest, *old)
} }
u.latest = newLatest u.latest = newLatest
return nil return true, nil
} }
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) { func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {

View file

@ -526,13 +526,7 @@ func (v StateResolution) CalculateAndStoreStateBeforeEvent(
isRejected bool, isRejected bool,
) (types.StateSnapshotNID, error) { ) (types.StateSnapshotNID, error) {
// Load the state at the prev events. // Load the state at the prev events.
prevEventRefs := event.PrevEvents() prevStates, err := v.db.StateAtEventIDs(ctx, event.PrevEventIDs())
prevEventIDs := make([]string, len(prevEventRefs))
for i := range prevEventRefs {
prevEventIDs[i] = prevEventRefs[i].EventID
}
prevStates, err := v.db.StateAtEventIDs(ctx, prevEventIDs)
if err != nil { if err != nil {
return 0, err return 0, err
} }