Refactor StoreEvent, add MaybeRedactEvent, create an EventDatabase (#2989)

This PR changes the following:
- `StoreEvent` now only stores an event (and possibly prev event),
instead of also doing redactions
- Adds a `MaybeRedactEvent` (pulled out from `StoreEvent`), which should
be called after storing events
- a few other things
This commit is contained in:
Till 2023-03-01 17:06:47 +01:00 committed by GitHub
parent 1aa70b0f56
commit 6c20f8f742
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
34 changed files with 488 additions and 420 deletions

View file

@ -21,11 +21,12 @@ import (
"errors"
"fmt"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/roomserver/acls"
@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
return err
}
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
if err != nil {
return err
}
@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
}
authEventIDs = util.UniqueStrings(authEventIDs)
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return fmt.Errorf("getAuthChain: %w", err)
}
@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
return nil
}
// QueryEventsByID implements api.RoomserverInternalAPI
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
// which room to use by querying the first events roomID.
func (r *Queryer) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
if len(request.EventIDs) == 0 {
return nil
}
var err error
// We didn't receive a room ID, we need to fetch it first before we can continue.
// This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
var roomInfo *types.RoomInfo
if request.RoomID == "" {
var eventNIDs map[string]types.EventMetadata
eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
if err != nil {
return err
}
if len(eventNIDs) == 0 {
return nil
}
roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
} else {
roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
}
if err != nil {
return err
}
if roomInfo == nil {
return nil
}
events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
if err != nil {
return err
}
for _, event := range events {
roomVersion, verr := r.roomVersion(event.RoomID())
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
}
return nil
@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
response.IsInRoom = stillInRoom
response.HasBeenInRoom = true
evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
// once. If we have more than one membership event, we need to get the state for each state entry.
if canShortCircuit {
if len(memberships) == 0 {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
} else {
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
}
if err != nil {
return fmt.Errorf("unable to get memberships at state: %w", err)
@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
}
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
}
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
events, err = r.DB.Events(ctx, info, eventNIDs)
if err != nil {
return fmt.Errorf("r.DB.Events: %w", err)
}
@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
return err
}
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
events, err = r.DB.Events(ctx, info, eventNIDs)
} else {
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
}
if err != nil {
@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *Queryer) QueryServerAllowedToSeeEvent(
ctx context.Context,
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
serverName gomatrixserverlib.ServerName,
eventID string,
) (allowed bool, err error) {
events, err := r.DB.EventNIDs(ctx, []string{eventID})
if err != nil {
return
}
if len(events) == 0 {
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
return
return allowed, nil
}
roomID := events[0].RoomID()
inRoomReq := &api.QueryServerJoinedToRoomRequest{
RoomID: roomID,
ServerName: request.ServerName,
}
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
}
info, err := r.DB.RoomInfo(ctx, roomID)
info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
if err != nil {
return err
return allowed, err
}
if info == nil || info.IsStub() {
return nil
return allowed, nil
}
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
var isInRoom bool
if r.IsLocalServerName(serverName) || serverName == "" {
isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
if err != nil {
return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
}
} else {
isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
if err != nil {
return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
}
}
return helpers.CheckServerAllowedToSeeEvent(
ctx, r.DB, info, eventID, serverName, isInRoom,
)
return
}
// QueryMissingEvents implements api.RoomserverInternalAPI
@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
eventsToFilter[id] = true
}
}
events, err := r.DB.EventsFromIDs(ctx, 0, front)
if len(front) == 0 {
return nil // no events to query, give up.
}
events, err := r.DB.EventNIDs(ctx, []string{front[0]})
if err != nil {
return err
}
if len(events) == 0 {
return nil // we are missing the events being asked to search from, give up.
}
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
if err != nil {
return err
}
if info == nil || info.IsStub() {
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
}
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
return err
}
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
if err != nil {
return err
}
@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
// TODO: this probably means it should be a different query operation...
if request.OnlyFetchAuthChain {
var authEvents []*gomatrixserverlib.Event
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs)
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
if err != nil {
return err
}
@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
if err != nil {
return err
}
@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
return nil, rejected, false, err
}
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
return events, rejected, false, err
}
type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events.
func GetAuthChain(
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
) ([]*gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new
@ -633,7 +659,7 @@ func GetAuthChain(
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
events, err := fn(ctx, 0, eventsToFetch)
events, err := fn(ctx, roomInfo, eventsToFetch)
if err != nil {
return nil, err
}
@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
}
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
if err != nil {
return err
}
@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
// For each of the joined users, let's see if we can get a valid
// membership event.
for _, joinNID := range joinNIDs {
events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
if err != nil || len(events) != 1 {
continue
}

View file

@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
}
// EventsFromIDs implements RoomserverInternalAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs {
res = append(res, types.Event{
EventNID: 0,
@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}