Refactor StoreEvent and create a new RoomDatabase interface (#2985)

This PR changes a few things:
- It pulls out the creation of several NIDs from the `StoreEvent`
function to make the functions more reusable
- Uses more caching when using those NIDs to avoid DB round trips
This commit is contained in:
Till 2023-02-24 09:40:20 +01:00 committed by GitHub
parent e6aa0955ff
commit ad07b169b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 459 additions and 302 deletions

View file

@ -31,7 +31,8 @@ import (
// the soft-fail bool.
func CheckForSoftFail(
ctx context.Context,
db storage.Database,
db storage.RoomDatabase,
roomInfo *types.RoomInfo,
event *gomatrixserverlib.HeaderedEvent,
stateEventIDs []string,
) (bool, error) {
@ -45,16 +46,6 @@ func CheckForSoftFail(
return true, fmt.Errorf("StateEntriesForEventIDs failed: %w", err)
}
} else {
// Work out if the room exists.
var roomInfo *types.RoomInfo
roomInfo, err = db.RoomInfo(ctx, event.RoomID())
if err != nil {
return false, fmt.Errorf("db.RoomNID: %w", err)
}
if roomInfo == nil || roomInfo.IsStub() {
return false, nil
}
// Then get the state entries for the current state snapshot.
// We'll use this to check if the event is allowed right now.
roomState := state.NewStateResolution(db, roomInfo)
@ -76,7 +67,7 @@ func CheckForSoftFail(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
authEvents, err := loadAuthEvents(ctx, db, roomInfo.RoomNID, stateNeeded, authStateEntries)
if err != nil {
return true, fmt.Errorf("loadAuthEvents: %w", err)
}
@ -93,7 +84,8 @@ func CheckForSoftFail(
// Returns the numeric IDs for the auth events.
func CheckAuthEvents(
ctx context.Context,
db storage.Database,
db storage.RoomDatabase,
roomNID types.RoomNID,
event *gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
@ -108,7 +100,7 @@ func CheckAuthEvents(
stateNeeded := gomatrixserverlib.StateNeededForAuth([]*gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
authEvents, err := loadAuthEvents(ctx, db, roomNID, stateNeeded, authStateEntries)
if err != nil {
return nil, fmt.Errorf("loadAuthEvents: %w", err)
}
@ -201,6 +193,7 @@ func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *
func loadAuthEvents(
ctx context.Context,
db state.StateResolutionStorage,
roomNID types.RoomNID,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
@ -223,7 +216,7 @@ func loadAuthEvents(
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
if result.events, err = db.Events(ctx, roomNID, eventNIDs); err != nil {
return
}
roomID := ""

View file

@ -85,7 +85,7 @@ func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverNam
return false, err
}
events, err := db.Events(ctx, eventNIDs)
events, err := db.Events(ctx, info.RoomNID, eventNIDs)
if err != nil {
return false, err
}
@ -157,7 +157,7 @@ func IsInvitePending(
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func GetMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs types.EventNIDs
@ -177,7 +177,7 @@ func GetMembershipsAtState(
util.Unique(eventNIDs)
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil {
return nil, err
}
@ -220,16 +220,16 @@ func StateBeforeEvent(ctx context.Context, db storage.Database, info *types.Room
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func MembershipAtEvent(ctx context.Context, db storage.Database, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
func MembershipAtEvent(ctx context.Context, db storage.RoomDatabase, info *types.RoomInfo, eventIDs []string, stateKeyNID types.EventStateKeyNID) (map[string][]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Fetch the state as it was when this event was fired
return roomState.LoadMembershipAtEvent(ctx, eventIDs, stateKeyNID)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, eventNIDs []types.EventNID,
) ([]*gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, eventNIDs)
stateEvents, err := db.Events(ctx, roomNID, eventNIDs)
if err != nil {
return nil, err
}
@ -242,13 +242,13 @@ func LoadEvents(
}
func LoadStateEvents(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
ctx context.Context, db storage.RoomDatabase, roomNID types.RoomNID, stateEntries []types.StateEntry,
) ([]*gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return LoadEvents(ctx, db, eventNIDs)
return LoadEvents(ctx, db, roomNID, eventNIDs)
}
func CheckServerAllowedToSeeEvent(
@ -326,7 +326,7 @@ func slowGetHistoryVisibilityState(
return nil, nil
}
return LoadStateEvents(ctx, db, filteredEntries)
return LoadStateEvents(ctx, db, info.RoomNID, filteredEntries)
}
// TODO: Remove this when we have tests to assert correctness of this function
@ -366,7 +366,7 @@ BFSLoop:
next = make([]string, 0)
}
// Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, front)
events, err = db.EventsFromIDs(ctx, info.RoomNID, front)
if err != nil {
return resultNIDs, redactEventIDs, err
}
@ -467,7 +467,7 @@ func QueryLatestEventsAndState(
return err
}
stateEvents, err := LoadStateEvents(ctx, db, stateEntries)
stateEvents, err := LoadStateEvents(ctx, db, roomInfo.RoomNID, stateEntries)
if err != nil {
return err
}

View file

@ -38,7 +38,18 @@ func TestIsInvitePendingWithoutNID(t *testing.T) {
var authNIDs []types.EventNID
for _, x := range room.Events() {
evNID, _, _, _, _, err := db.StoreEvent(context.Background(), x.Event, authNIDs, false)
roomNID, err := db.GetOrCreateRoomNID(context.Background(), x.Unwrap())
assert.NoError(t, err)
assert.Greater(t, roomNID, types.RoomNID(0))
eventTypeNID, err := db.GetOrCreateEventTypeNID(context.Background(), x.Type())
assert.NoError(t, err)
assert.Greater(t, eventTypeNID, types.EventTypeNID(0))
eventStateKeyNID, err := db.GetOrCreateEventStateKeyNID(context.Background(), x.StateKey())
assert.NoError(t, err)
evNID, _, _, _, err := db.StoreEvent(context.Background(), x.Event, roomNID, eventTypeNID, eventStateKeyNID, authNIDs, false)
assert.NoError(t, err)
authNIDs = append(authNIDs, evNID)
}