mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 13:52:46 +00:00
Refactor StoreEvent
, add MaybeRedactEvent
, create an EventDatabase
(#2989)
This PR changes the following: - `StoreEvent` now only stores an event (and possibly prev event), instead of also doing redactions - Adds a `MaybeRedactEvent` (pulled out from `StoreEvent`), which should be called after storing events - a few other things
This commit is contained in:
parent
1aa70b0f56
commit
6c20f8f742
34 changed files with 488 additions and 420 deletions
|
@ -21,11 +21,12 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/roomserver/acls"
|
||||
|
@ -102,7 +103,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||
return err
|
||||
}
|
||||
|
||||
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info.RoomNID, stateEntries)
|
||||
stateEvents, err := helpers.LoadStateEvents(ctx, r.DB, info, stateEntries)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -114,7 +115,7 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||
}
|
||||
authEventIDs = util.UniqueStrings(authEventIDs)
|
||||
|
||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getAuthChain: %w", err)
|
||||
}
|
||||
|
@ -132,24 +133,46 @@ func (r *Queryer) QueryStateAfterEvents(
|
|||
return nil
|
||||
}
|
||||
|
||||
// QueryEventsByID implements api.RoomserverInternalAPI
|
||||
// QueryEventsByID queries a list of events by event ID for one room. If no room is specified, it will try to determine
|
||||
// which room to use by querying the first events roomID.
|
||||
func (r *Queryer) QueryEventsByID(
|
||||
ctx context.Context,
|
||||
request *api.QueryEventsByIDRequest,
|
||||
response *api.QueryEventsByIDResponse,
|
||||
) error {
|
||||
events, err := r.DB.EventsFromIDs(ctx, 0, request.EventIDs)
|
||||
if len(request.EventIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
var err error
|
||||
// We didn't receive a room ID, we need to fetch it first before we can continue.
|
||||
// This happens for e.g. ` /_matrix/federation/v1/event/{eventId}`
|
||||
var roomInfo *types.RoomInfo
|
||||
if request.RoomID == "" {
|
||||
var eventNIDs map[string]types.EventMetadata
|
||||
eventNIDs, err = r.DB.EventNIDs(ctx, []string{request.EventIDs[0]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(eventNIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
roomInfo, err = r.DB.RoomInfoByNID(ctx, eventNIDs[request.EventIDs[0]].RoomNID)
|
||||
} else {
|
||||
roomInfo, err = r.DB.RoomInfo(ctx, request.RoomID)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if roomInfo == nil {
|
||||
return nil
|
||||
}
|
||||
events, err := r.DB.EventsFromIDs(ctx, roomInfo, request.EventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
roomVersion, verr := r.roomVersion(event.RoomID())
|
||||
if verr != nil {
|
||||
return verr
|
||||
}
|
||||
|
||||
response.Events = append(response.Events, event.Headered(roomVersion))
|
||||
response.Events = append(response.Events, event.Headered(roomInfo.RoomVersion))
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -186,7 +209,7 @@ func (r *Queryer) QueryMembershipForUser(
|
|||
response.IsInRoom = stillInRoom
|
||||
response.HasBeenInRoom = true
|
||||
|
||||
evs, err := r.DB.Events(ctx, info.RoomNID, []types.EventNID{membershipEventNID})
|
||||
evs, err := r.DB.Events(ctx, info, []types.EventNID{membershipEventNID})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -268,10 +291,10 @@ func (r *Queryer) QueryMembershipAtEvent(
|
|||
// once. If we have more than one membership event, we need to get the state for each state entry.
|
||||
if canShortCircuit {
|
||||
if len(memberships) == 0 {
|
||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
|
||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
|
||||
}
|
||||
} else {
|
||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntry, false)
|
||||
memberships, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntry, false)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to get memberships at state: %w", err)
|
||||
|
@ -318,7 +341,7 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
}
|
||||
return fmt.Errorf("r.DB.GetMembershipEventNIDsForRoom: %w", err)
|
||||
}
|
||||
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
|
||||
events, err = r.DB.Events(ctx, info, eventNIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.DB.Events: %w", err)
|
||||
}
|
||||
|
@ -357,14 +380,14 @@ func (r *Queryer) QueryMembershipsForRoom(
|
|||
return err
|
||||
}
|
||||
|
||||
events, err = r.DB.Events(ctx, info.RoomNID, eventNIDs)
|
||||
events, err = r.DB.Events(ctx, info, eventNIDs)
|
||||
} else {
|
||||
stateEntries, err = helpers.StateBeforeEvent(ctx, r.DB, info, membershipEventNID)
|
||||
if err != nil {
|
||||
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
|
||||
return err
|
||||
}
|
||||
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info.RoomNID, stateEntries, request.JoinedOnly)
|
||||
events, err = helpers.GetMembershipsAtState(ctx, r.DB, info, stateEntries, request.JoinedOnly)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -412,39 +435,39 @@ func (r *Queryer) QueryServerJoinedToRoom(
|
|||
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
|
||||
func (r *Queryer) QueryServerAllowedToSeeEvent(
|
||||
ctx context.Context,
|
||||
request *api.QueryServerAllowedToSeeEventRequest,
|
||||
response *api.QueryServerAllowedToSeeEventResponse,
|
||||
) (err error) {
|
||||
events, err := r.DB.EventsFromIDs(ctx, 0, []string{request.EventID})
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
eventID string,
|
||||
) (allowed bool, err error) {
|
||||
events, err := r.DB.EventNIDs(ctx, []string{eventID})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(events) == 0 {
|
||||
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
|
||||
return
|
||||
return allowed, nil
|
||||
}
|
||||
roomID := events[0].RoomID()
|
||||
|
||||
inRoomReq := &api.QueryServerJoinedToRoomRequest{
|
||||
RoomID: roomID,
|
||||
ServerName: request.ServerName,
|
||||
}
|
||||
inRoomRes := &api.QueryServerJoinedToRoomResponse{}
|
||||
if err = r.QueryServerJoinedToRoom(ctx, inRoomReq, inRoomRes); err != nil {
|
||||
return fmt.Errorf("r.Queryer.QueryServerJoinedToRoom: %w", err)
|
||||
}
|
||||
|
||||
info, err := r.DB.RoomInfo(ctx, roomID)
|
||||
info, err := r.DB.RoomInfoByNID(ctx, events[eventID].RoomNID)
|
||||
if err != nil {
|
||||
return err
|
||||
return allowed, err
|
||||
}
|
||||
if info == nil || info.IsStub() {
|
||||
return nil
|
||||
return allowed, nil
|
||||
}
|
||||
response.AllowedToSeeEvent, err = helpers.CheckServerAllowedToSeeEvent(
|
||||
ctx, r.DB, info, request.EventID, request.ServerName, inRoomRes.IsInRoom,
|
||||
var isInRoom bool
|
||||
if r.IsLocalServerName(serverName) || serverName == "" {
|
||||
isInRoom, err = r.DB.GetLocalServerInRoom(ctx, info.RoomNID)
|
||||
if err != nil {
|
||||
return allowed, fmt.Errorf("r.DB.GetLocalServerInRoom: %w", err)
|
||||
}
|
||||
} else {
|
||||
isInRoom, err = r.DB.GetServerInRoom(ctx, info.RoomNID, serverName)
|
||||
if err != nil {
|
||||
return allowed, fmt.Errorf("r.DB.GetServerInRoom: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return helpers.CheckServerAllowedToSeeEvent(
|
||||
ctx, r.DB, info, eventID, serverName, isInRoom,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// QueryMissingEvents implements api.RoomserverInternalAPI
|
||||
|
@ -466,19 +489,22 @@ func (r *Queryer) QueryMissingEvents(
|
|||
eventsToFilter[id] = true
|
||||
}
|
||||
}
|
||||
events, err := r.DB.EventsFromIDs(ctx, 0, front)
|
||||
if len(front) == 0 {
|
||||
return nil // no events to query, give up.
|
||||
}
|
||||
events, err := r.DB.EventNIDs(ctx, []string{front[0]})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return nil // we are missing the events being asked to search from, give up.
|
||||
}
|
||||
info, err := r.DB.RoomInfo(ctx, events[0].RoomID())
|
||||
info, err := r.DB.RoomInfoByNID(ctx, events[front[0]].RoomNID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info == nil || info.IsStub() {
|
||||
return fmt.Errorf("missing RoomInfo for room %s", events[0].RoomID())
|
||||
return fmt.Errorf("missing RoomInfo for room %d", events[front[0]].RoomNID)
|
||||
}
|
||||
|
||||
resultNIDs, redactEventIDs, err := helpers.ScanEventTree(ctx, r.DB, info, front, visited, request.Limit, request.ServerName)
|
||||
|
@ -486,7 +512,7 @@ func (r *Queryer) QueryMissingEvents(
|
|||
return err
|
||||
}
|
||||
|
||||
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info.RoomNID, resultNIDs)
|
||||
loadedEvents, err := helpers.LoadEvents(ctx, r.DB, info, resultNIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -529,7 +555,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||
// TODO: this probably means it should be a different query operation...
|
||||
if request.OnlyFetchAuthChain {
|
||||
var authEvents []*gomatrixserverlib.Event
|
||||
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, request.AuthEventIDs)
|
||||
authEvents, err = GetAuthChain(ctx, r.DB.EventsFromIDs, info, request.AuthEventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -556,7 +582,7 @@ func (r *Queryer) QueryStateAndAuthChain(
|
|||
}
|
||||
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
|
||||
|
||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
|
||||
authEvents, err := GetAuthChain(ctx, r.DB.EventsFromIDs, info, authEventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -611,18 +637,18 @@ func (r *Queryer) loadStateAtEventIDs(ctx context.Context, roomInfo *types.RoomI
|
|||
return nil, rejected, false, err
|
||||
}
|
||||
|
||||
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo.RoomNID, stateEntries)
|
||||
events, err := helpers.LoadStateEvents(ctx, r.DB, roomInfo, stateEntries)
|
||||
return events, rejected, false, err
|
||||
}
|
||||
|
||||
type eventsFromIDs func(context.Context, types.RoomNID, []string) ([]types.Event, error)
|
||||
type eventsFromIDs func(context.Context, *types.RoomInfo, []string) ([]types.Event, error)
|
||||
|
||||
// GetAuthChain fetches the auth chain for the given auth events. An auth chain
|
||||
// is the list of all events that are referenced in the auth_events section, and
|
||||
// all their auth_events, recursively. The returned set of events contain the
|
||||
// given events. Will *not* error if we don't have all auth events.
|
||||
func GetAuthChain(
|
||||
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
|
||||
ctx context.Context, fn eventsFromIDs, roomInfo *types.RoomInfo, authEventIDs []string,
|
||||
) ([]*gomatrixserverlib.Event, error) {
|
||||
// List of event IDs to fetch. On each pass, these events will be requested
|
||||
// from the database and the `eventsToFetch` will be updated with any new
|
||||
|
@ -633,7 +659,7 @@ func GetAuthChain(
|
|||
|
||||
for len(eventsToFetch) > 0 {
|
||||
// Try to retrieve the events from the database.
|
||||
events, err := fn(ctx, 0, eventsToFetch)
|
||||
events, err := fn(ctx, roomInfo, eventsToFetch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -852,7 +878,7 @@ func (r *Queryer) QueryServerBannedFromRoom(ctx context.Context, req *api.QueryS
|
|||
}
|
||||
|
||||
func (r *Queryer) QueryAuthChain(ctx context.Context, req *api.QueryAuthChainRequest, res *api.QueryAuthChainResponse) error {
|
||||
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, req.EventIDs)
|
||||
chain, err := GetAuthChain(ctx, r.DB.EventsFromIDs, nil, req.EventIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -971,7 +997,7 @@ func (r *Queryer) QueryRestrictedJoinAllowed(ctx context.Context, req *api.Query
|
|||
// For each of the joined users, let's see if we can get a valid
|
||||
// membership event.
|
||||
for _, joinNID := range joinNIDs {
|
||||
events, err := r.DB.Events(ctx, roomInfo.RoomNID, []types.EventNID{joinNID})
|
||||
events, err := r.DB.Events(ctx, roomInfo, []types.EventNID{joinNID})
|
||||
if err != nil || len(events) != 1 {
|
||||
continue
|
||||
}
|
||||
|
|
|
@ -80,7 +80,7 @@ func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
|
|||
}
|
||||
|
||||
// EventsFromIDs implements RoomserverInternalAPIEventDB
|
||||
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) (res []types.Event, err error) {
|
||||
func (db *getEventDB) EventsFromIDs(ctx context.Context, roomInfo *types.RoomInfo, eventIDs []string) (res []types.Event, err error) {
|
||||
for _, evID := range eventIDs {
|
||||
res = append(res, types.Event{
|
||||
EventNID: 0,
|
||||
|
@ -106,7 +106,7 @@ func TestGetAuthChainSingle(t *testing.T) {
|
|||
t.Fatalf("Failed to add events to db: %v", err)
|
||||
}
|
||||
|
||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
|
||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e"})
|
||||
if err != nil {
|
||||
t.Fatalf("getAuthChain failed: %v", err)
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ func TestGetAuthChainMultiple(t *testing.T) {
|
|||
t.Fatalf("Failed to add events to db: %v", err)
|
||||
}
|
||||
|
||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
|
||||
result, err := GetAuthChain(context.TODO(), db.EventsFromIDs, nil, []string{"e", "f"})
|
||||
if err != nil {
|
||||
t.Fatalf("getAuthChain failed: %v", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue