Use PDU not *Event in HeaderedEvent (#3073)

Requires https://github.com/matrix-org/gomatrixserverlib/pull/376

This has numerous upsides:
 - Less type casting to `*Event` is required.
- Making Dendrite work with `PDU` interfaces means we can swap out Event
impls more easily.
 - Tests which represent weird event shapes are easier to write.

Part of a series of refactors on GMSL.
This commit is contained in:
kegsay 2023-05-02 15:03:16 +01:00 committed by GitHub
parent 696cbb70b8
commit f5b3144dc3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
64 changed files with 296 additions and 284 deletions

View file

@ -102,7 +102,7 @@ func (r *Inputer) processRoomEvent(
// Parse and validate the event JSON
headered := input.Event
event := headered.Event
event := headered.PDU
logger := util.GetLogger(ctx).WithFields(logrus.Fields{
"event_id": event.EventID(),
"room_id": event.RoomID(),
@ -232,7 +232,7 @@ func (r *Inputer) processRoomEvent(
roomsMu: internal.NewMutexByRoom(),
servers: serverRes.ServerNames,
hadEvents: map[string]bool{},
haveEvents: map[string]*gomatrixserverlib.Event{},
haveEvents: map[string]gomatrixserverlib.PDU{},
}
var stateSnapshot *parsedRespState
if stateSnapshot, err = missingState.processEventWithMissingState(ctx, event, headered.Version()); err != nil {
@ -389,8 +389,8 @@ func (r *Inputer) processRoomEvent(
// we do this after calculating state for this event as we may need to get power levels
var (
redactedEventID string
redactionEvent *gomatrixserverlib.Event
redactedEvent *gomatrixserverlib.Event
redactionEvent gomatrixserverlib.PDU
redactedEvent gomatrixserverlib.PDU
)
if !isRejected && !isCreateEvent {
resolver := state.NewStateResolution(r.DB, roomInfo)
@ -467,7 +467,7 @@ func (r *Inputer) processRoomEvent(
Type: api.OutputTypeRedactedEvent,
RedactedEvent: &api.OutputRedactedEvent{
RedactedEventID: redactedEventID,
RedactedBecause: &types.HeaderedEvent{Event: redactionEvent},
RedactedBecause: &types.HeaderedEvent{PDU: redactionEvent},
},
},
})
@ -490,7 +490,7 @@ func (r *Inputer) processRoomEvent(
}
// handleRemoteRoomUpgrade updates published rooms and room aliases
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event *gomatrixserverlib.Event) error {
func (r *Inputer) handleRemoteRoomUpgrade(ctx context.Context, event gomatrixserverlib.PDU) error {
oldRoomID := event.RoomID()
newRoomID := gjson.GetBytes(event.Content(), "replacement_room").Str
return r.DB.UpgradeRoom(ctx, oldRoomID, newRoomID, event.Sender())
@ -509,9 +509,9 @@ func (r *Inputer) processStateBefore(
missingPrev bool,
) (historyVisibility gomatrixserverlib.HistoryVisibility, rejectionErr error, err error) {
historyVisibility = gomatrixserverlib.HistoryVisibilityShared // Default to shared.
event := input.Event.Event
event := input.Event.PDU
isCreateEvent := event.Type() == spec.MRoomCreate && event.StateKeyEquals("")
var stateBeforeEvent []*gomatrixserverlib.Event
var stateBeforeEvent []gomatrixserverlib.PDU
switch {
case isCreateEvent:
// There's no state before a create event so there is nothing
@ -524,9 +524,9 @@ func (r *Inputer) processStateBefore(
if err != nil {
return "", nil, fmt.Errorf("r.DB.EventsFromIDs: %w", err)
}
stateBeforeEvent = make([]*gomatrixserverlib.Event, 0, len(stateEvents))
stateBeforeEvent = make([]gomatrixserverlib.PDU, 0, len(stateEvents))
for _, entry := range stateEvents {
stateBeforeEvent = append(stateBeforeEvent, entry.Event)
stateBeforeEvent = append(stateBeforeEvent, entry.PDU)
}
case missingPrev:
// We don't know all of the prev events, so we can't work out
@ -567,9 +567,9 @@ func (r *Inputer) processStateBefore(
rejectionErr = fmt.Errorf("prev events of %q are not known", event.EventID())
return
default:
stateBeforeEvent = make([]*gomatrixserverlib.Event, len(stateBeforeRes.StateEvents))
stateBeforeEvent = make([]gomatrixserverlib.PDU, len(stateBeforeRes.StateEvents))
for i := range stateBeforeRes.StateEvents {
stateBeforeEvent[i] = stateBeforeRes.StateEvents[i].Event
stateBeforeEvent[i] = stateBeforeRes.StateEvents[i].PDU
}
}
}
@ -626,7 +626,7 @@ func (r *Inputer) fetchAuthEvents(
for _, authEventID := range authEventIDs {
authEvents, err := r.DB.EventsFromIDs(ctx, roomInfo, []string{authEventID})
if err != nil || len(authEvents) == 0 || authEvents[0].Event == nil {
if err != nil || len(authEvents) == 0 || authEvents[0].PDU == nil {
unknown[authEventID] = struct{}{}
continue
}
@ -641,7 +641,7 @@ func (r *Inputer) fetchAuthEvents(
}
known[authEventID] = &ev // don't take the pointer of the iterated event
if !isRejected {
if err = auth.AddEvent(ev.Event); err != nil {
if err = auth.AddEvent(ev.PDU); err != nil {
return fmt.Errorf("auth.AddEvent: %w", err)
}
}
@ -745,7 +745,7 @@ nextAuthEvent:
// Now we know about this event, it was stored and the signatures were OK.
known[authEvent.EventID()] = &types.Event{
EventNID: eventNID,
Event: authEvent.(*gomatrixserverlib.Event),
PDU: authEvent,
}
}
@ -757,7 +757,7 @@ func (r *Inputer) calculateAndSetState(
input *api.InputRoomEvent,
roomInfo *types.RoomInfo,
stateAtEvent *types.StateAtEvent,
event *gomatrixserverlib.Event,
event gomatrixserverlib.PDU,
isRejected bool,
) error {
trace, ctx := internal.StartRegion(ctx, "calculateAndSetState")
@ -799,7 +799,7 @@ func (r *Inputer) calculateAndSetState(
}
// kickGuests kicks guests users from m.room.guest_access rooms, if guest access is now prohibited.
func (r *Inputer) kickGuests(ctx context.Context, event *gomatrixserverlib.Event, roomInfo *types.RoomInfo) error {
func (r *Inputer) kickGuests(ctx context.Context, event gomatrixserverlib.PDU, roomInfo *types.RoomInfo) error {
membershipNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomInfo.RoomNID, true, true)
if err != nil {
return err

View file

@ -18,21 +18,21 @@ func Test_EventAuth(t *testing.T) {
room2 := test.NewRoom(t, alice, test.RoomPreset(test.PresetPublicChat))
authEventIDs := make([]string, 0, 4)
authEvents := []*gomatrixserverlib.Event{}
authEvents := []gomatrixserverlib.PDU{}
// Add the legal auth events from room2
for _, x := range room2.Events() {
if x.Type() == spec.MRoomCreate {
authEventIDs = append(authEventIDs, x.EventID())
authEvents = append(authEvents, x.Event)
authEvents = append(authEvents, x.PDU)
}
if x.Type() == spec.MRoomPowerLevels {
authEventIDs = append(authEventIDs, x.EventID())
authEvents = append(authEvents, x.Event)
authEvents = append(authEvents, x.PDU)
}
if x.Type() == spec.MRoomJoinRules {
authEventIDs = append(authEventIDs, x.EventID())
authEvents = append(authEvents, x.Event)
authEvents = append(authEvents, x.PDU)
}
}
@ -40,7 +40,7 @@ func Test_EventAuth(t *testing.T) {
for _, x := range room1.Events() {
if x.Type() == spec.MRoomMember {
authEventIDs = append(authEventIDs, x.EventID())
authEvents = append(authEvents, x.Event)
authEvents = append(authEvents, x.PDU)
}
}
@ -58,7 +58,7 @@ func Test_EventAuth(t *testing.T) {
}
// Finally check that the event is NOT allowed
if err := gomatrixserverlib.Allowed(ev.Event, &allower); err == nil {
if err := gomatrixserverlib.Allowed(ev.PDU, &allower); err == nil {
t.Fatalf("event should not be allowed, but it was")
}
}

View file

@ -53,7 +53,7 @@ func (r *Inputer) updateLatestEvents(
ctx context.Context,
roomInfo *types.RoomInfo,
stateAtEvent types.StateAtEvent,
event *gomatrixserverlib.Event,
event gomatrixserverlib.PDU,
sendAsServer string,
transactionID *api.TransactionID,
rewritesState bool,
@ -101,7 +101,7 @@ type latestEventsUpdater struct {
updater *shared.RoomUpdater
roomInfo *types.RoomInfo
stateAtEvent types.StateAtEvent
event *gomatrixserverlib.Event
event gomatrixserverlib.PDU
transactionID *api.TransactionID
rewritesState bool
// Which server to send this event as.
@ -326,7 +326,7 @@ func (u *latestEventsUpdater) latestState() error {
// true if the new event is included in those extremites, false otherwise.
func (u *latestEventsUpdater) calculateLatest(
oldLatest []types.StateAtEventAndReference,
newEvent *gomatrixserverlib.Event,
newEvent gomatrixserverlib.PDU,
newStateAndRef types.StateAtEventAndReference,
) (bool, error) {
trace, _ := internal.StartRegion(u.ctx, "calculateLatest")
@ -393,7 +393,7 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
}
ore := api.OutputNewRoomEvent{
Event: &types.HeaderedEvent{Event: u.event},
Event: &types.HeaderedEvent{PDU: u.event},
RewritesState: u.rewritesState,
LastSentEventID: u.lastEventIDSent,
LatestEventIDs: latestEventIDs,

View file

@ -22,19 +22,19 @@ import (
)
type parsedRespState struct {
AuthEvents []*gomatrixserverlib.Event
StateEvents []*gomatrixserverlib.Event
AuthEvents []gomatrixserverlib.PDU
StateEvents []gomatrixserverlib.PDU
}
func (p *parsedRespState) Events() []gomatrixserverlib.PDU {
eventsByID := make(map[string]*gomatrixserverlib.Event, len(p.AuthEvents)+len(p.StateEvents))
eventsByID := make(map[string]gomatrixserverlib.PDU, len(p.AuthEvents)+len(p.StateEvents))
for i, event := range p.AuthEvents {
eventsByID[event.EventID()] = p.AuthEvents[i]
}
for i, event := range p.StateEvents {
eventsByID[event.EventID()] = p.StateEvents[i]
}
allEvents := make([]*gomatrixserverlib.Event, 0, len(eventsByID))
allEvents := make([]gomatrixserverlib.PDU, 0, len(eventsByID))
for _, event := range eventsByID {
allEvents = append(allEvents, event)
}
@ -55,7 +55,7 @@ type missingStateReq struct {
servers []spec.ServerName
hadEvents map[string]bool
hadEventsMutex sync.Mutex
haveEvents map[string]*gomatrixserverlib.Event
haveEvents map[string]gomatrixserverlib.PDU
haveEventsMutex sync.Mutex
}
@ -63,7 +63,7 @@ type missingStateReq struct {
// request, as called from processRoomEvent.
// nolint:gocyclo
func (t *missingStateReq) processEventWithMissingState(
ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion,
ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion,
) (*parsedRespState, error) {
trace, ctx := internal.StartRegion(ctx, "processEventWithMissingState")
defer trace.EndRegion()
@ -107,7 +107,7 @@ func (t *missingStateReq) processEventWithMissingState(
for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld,
Event: &types.HeaderedEvent{Event: newEvent},
Event: &types.HeaderedEvent{PDU: newEvent},
Origin: t.origin,
SendAsServer: api.DoNotSendToOtherServers,
})
@ -156,7 +156,7 @@ func (t *missingStateReq) processEventWithMissingState(
}
outlierRoomEvents = append(outlierRoomEvents, api.InputRoomEvent{
Kind: api.KindOutlier,
Event: &types.HeaderedEvent{Event: outlier.(*gomatrixserverlib.Event)},
Event: &types.HeaderedEvent{PDU: outlier},
Origin: t.origin,
})
}
@ -186,7 +186,7 @@ func (t *missingStateReq) processEventWithMissingState(
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld,
Event: &types.HeaderedEvent{Event: backwardsExtremity},
Event: &types.HeaderedEvent{PDU: backwardsExtremity},
Origin: t.origin,
HasState: true,
StateEventIDs: stateIDs,
@ -205,7 +205,7 @@ func (t *missingStateReq) processEventWithMissingState(
for _, newEvent := range newEvents {
err = t.inputer.processRoomEvent(ctx, t.virtualHost, &api.InputRoomEvent{
Kind: api.KindOld,
Event: &types.HeaderedEvent{Event: newEvent},
Event: &types.HeaderedEvent{PDU: newEvent},
Origin: t.origin,
SendAsServer: api.DoNotSendToOtherServers,
})
@ -243,7 +243,7 @@ func (t *missingStateReq) processEventWithMissingState(
return resolvedState, nil
}
func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) {
func (t *missingStateReq) lookupResolvedStateBeforeEvent(ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion) (*parsedRespState, error) {
trace, ctx := internal.StartRegion(ctx, "lookupResolvedStateBeforeEvent")
defer trace.EndRegion()
@ -368,7 +368,7 @@ func (t *missingStateReq) lookupStateAfterEvent(ctx context.Context, roomVersion
return respState, false, nil
}
func (t *missingStateReq) cacheAndReturn(ev *gomatrixserverlib.Event) *gomatrixserverlib.Event {
func (t *missingStateReq) cacheAndReturn(ev gomatrixserverlib.PDU) gomatrixserverlib.PDU {
t.haveEventsMutex.Lock()
defer t.haveEventsMutex.Unlock()
if cached, exists := t.haveEvents[ev.EventID()]; exists {
@ -403,11 +403,11 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
t.log.WithError(err).Warnf("failed to load state events locally")
return nil
}
res.StateEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents))
res.StateEvents = make([]gomatrixserverlib.PDU, 0, len(stateEvents))
for _, ev := range stateEvents {
// set the event from the haveEvents cache - this means we will share pointers with other prev_event branches for this
// processEvent request, which is better for memory.
res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.Event))
res.StateEvents = append(res.StateEvents, t.cacheAndReturn(ev.PDU))
t.hadEvent(ev.EventID())
}
@ -415,7 +415,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
stateEvents, stateEventNIDs, stateEntries, stateAtEvents = nil, nil, nil, nil // nolint:ineffassign
missingAuthEvents := map[string]bool{}
res.AuthEvents = make([]*gomatrixserverlib.Event, 0, len(stateEvents)*3)
res.AuthEvents = make([]gomatrixserverlib.PDU, 0, len(stateEvents)*3)
for _, ev := range stateEvents {
t.haveEventsMutex.Lock()
for _, ae := range ev.AuthEventIDs() {
@ -440,7 +440,7 @@ func (t *missingStateReq) lookupStateAfterEventLocally(ctx context.Context, even
return nil
}
for i, ev := range events {
res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].Event))
res.AuthEvents = append(res.AuthEvents, t.cacheAndReturn(events[i].PDU))
t.hadEvent(ev.EventID())
}
}
@ -459,12 +459,12 @@ func (t *missingStateReq) lookupStateBeforeEvent(ctx context.Context, roomVersio
return t.lookupMissingStateViaStateIDs(ctx, roomID, eventID, roomVersion)
}
func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity *gomatrixserverlib.Event) (*parsedRespState, error) {
func (t *missingStateReq) resolveStatesAndCheck(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, states []*parsedRespState, backwardsExtremity gomatrixserverlib.PDU) (*parsedRespState, error) {
trace, ctx := internal.StartRegion(ctx, "resolveStatesAndCheck")
defer trace.EndRegion()
var authEventList []*gomatrixserverlib.Event
var stateEventList []*gomatrixserverlib.Event
var authEventList []gomatrixserverlib.PDU
var stateEventList []gomatrixserverlib.PDU
for _, state := range states {
authEventList = append(authEventList, state.AuthEvents...)
stateEventList = append(stateEventList, state.StateEvents...)
@ -485,7 +485,7 @@ retryAllowedState:
case verifySigError:
return &parsedRespState{
AuthEvents: authEventList,
StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents),
StateEvents: resolvedStateEvents,
}, nil
case nil:
// do nothing
@ -501,13 +501,13 @@ retryAllowedState:
}
return &parsedRespState{
AuthEvents: authEventList,
StateEvents: gomatrixserverlib.TempCastToEvents(resolvedStateEvents),
StateEvents: resolvedStateEvents,
}, nil
}
// get missing events for `e`. If `isGapFilled`=true then `newEvents` contains all the events to inject,
// without `e`. If `isGapFilled=false` then `newEvents` contains the response to /get_missing_events
func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserverlib.Event, roomVersion gomatrixserverlib.RoomVersion) (newEvents []*gomatrixserverlib.Event, isGapFilled, prevStateKnown bool, err error) {
func (t *missingStateReq) getMissingEvents(ctx context.Context, e gomatrixserverlib.PDU, roomVersion gomatrixserverlib.RoomVersion) (newEvents []gomatrixserverlib.PDU, isGapFilled, prevStateKnown bool, err error) {
trace, ctx := internal.StartRegion(ctx, "getMissingEvents")
defer trace.EndRegion()
@ -560,7 +560,7 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
// Make sure events from the missingResp are using the cache - missing events
// will be added and duplicates will be removed.
missingEvents := make([]*gomatrixserverlib.Event, 0, len(missingResp.Events))
missingEvents := make([]gomatrixserverlib.PDU, 0, len(missingResp.Events))
for _, ev := range missingResp.Events.UntrustedEvents(roomVersion) {
if err = gomatrixserverlib.VerifyEventSignatures(ctx, ev, t.keys); err != nil {
continue
@ -570,9 +570,8 @@ func (t *missingStateReq) getMissingEvents(ctx context.Context, e *gomatrixserve
logger.Debugf("get_missing_events returned %d events (%d passed signature checks)", len(missingResp.Events), len(missingEvents))
// topologically sort and sanity check that we are making forward progress
newEventsPDUs := gomatrixserverlib.ReverseTopologicalOrdering(
newEvents = gomatrixserverlib.ReverseTopologicalOrdering(
gomatrixserverlib.ToPDUs(missingEvents), gomatrixserverlib.TopologicalOrderByPrevEvents)
newEvents = gomatrixserverlib.TempCastToEvents(newEventsPDUs)
shouldHaveSomeEventIDs := e.PrevEventIDs()
hasPrevEvent := false
Event:
@ -618,7 +617,7 @@ Event:
return newEvents, true, t.isPrevStateKnown(ctx, e), nil
}
func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e *gomatrixserverlib.Event) bool {
func (t *missingStateReq) isPrevStateKnown(ctx context.Context, e gomatrixserverlib.PDU) bool {
expected := len(e.PrevEventIDs())
state, err := t.db.StateAtEventIDs(ctx, e.PrevEventIDs())
if err != nil || len(state) != expected {
@ -707,7 +706,7 @@ func (t *missingStateReq) lookupMissingStateViaStateIDs(ctx context.Context, roo
}
for i, ev := range events {
events[i].Event = t.cacheAndReturn(events[i].Event)
events[i].PDU = t.cacheAndReturn(events[i].PDU)
t.hadEvent(ev.EventID())
evID := events[i].EventID()
if missing[evID] {
@ -839,7 +838,7 @@ func (t *missingStateReq) createRespStateFromStateIDs(
return &respState, nil
}
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (*gomatrixserverlib.Event, error) {
func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixserverlib.RoomVersion, _, missingEventID string, localFirst bool) (gomatrixserverlib.PDU, error) {
trace, ctx := internal.StartRegion(ctx, "lookupEvent")
defer trace.EndRegion()
@ -854,7 +853,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
if err != nil {
t.log.Warnf("Failed to query roomserver for missing event %s: %s - falling back to remote", missingEventID, err)
} else if len(events) == 1 {
return events[0].Event, nil
return events[0].PDU, nil
}
}
var event *gomatrixserverlib.Event
@ -894,7 +893,7 @@ func (t *missingStateReq) lookupEvent(ctx context.Context, roomVersion gomatrixs
return t.cacheAndReturn(event), nil
}
func checkAllowedByState(e *gomatrixserverlib.Event, stateEvents []gomatrixserverlib.PDU) error {
func checkAllowedByState(e gomatrixserverlib.PDU, stateEvents []gomatrixserverlib.PDU) error {
authUsingState := gomatrixserverlib.NewAuthEvents(nil)
for i := range stateEvents {
err := authUsingState.AddEvent(stateEvents[i])

View file

@ -45,7 +45,7 @@ func TestSingleTransactionOnInput(t *testing.T) {
}
in := api.InputRoomEvent{
Kind: api.KindOutlier, // don't panic if we generate an output event
Event: &types.HeaderedEvent{Event: event},
Event: &types.HeaderedEvent{PDU: event},
}
inputter := &input.Inputer{