Refactor roomserver/internal - split perform stuff out (#1380)

- New package `perform` which contains all `Perform` functions
- New package `helpers` which contains helper functions used by both
  perform and query/input functions.
- Perform invite/leave have no idea how to `WriteOutputEvents` and this
  is now returned from `PerformInvite` or `PerformLeave` respectively.

Still to do:
 - RSAPI is fed into the inviter/joiner/leaver - this introduces circular
   logic so will need to be removed.
 - Put query operations in a `query` package.
 - Put input operations (and output) in an `input` package.
 - Factor out helper functions as much as possible, possibly rejigging the
   storage layer in the process.
This commit is contained in:
Kegsay 2020-09-02 13:47:31 +01:00 committed by GitHub
parent 02a73f29f8
commit e473320e73
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 820 additions and 647 deletions

View file

@ -0,0 +1,244 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package helpers
import (
"context"
"sort"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// CheckAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
func CheckAuthEvents(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
if err != nil {
return nil, err
}
// TODO: check for duplicate state keys here.
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return nil, err
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
return nil, err
}
// Return the numeric IDs for the auth events.
result := make([]types.EventNID, len(authStateEntries))
for i := range authStateEntries {
result[i] = authStateEntries[i].EventNID
}
return result, nil
}
type authEvents struct {
stateKeyNIDMap map[string]types.EventStateKeyNID
state stateEntryMap
events EventMap
}
// Create implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
}
// PowerLevels implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
}
// JoinRules implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
}
// Memmber implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
}
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
}
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
if !ok {
return nil
}
event, ok := ae.events.Lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
if !ok {
return nil
}
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: stateKeyNID,
})
if !ok {
return nil
}
event, ok := ae.events.Lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
// loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents(
ctx context.Context,
db storage.Database,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
// Look up the numeric IDs for the state keys needed for auth.
var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
return
}
// Load the events we need.
result.state = state
var eventNIDs []types.EventNID
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
for _, keyTuple := range keyTuplesNeeded {
eventNID, ok := result.state.lookup(keyTuple)
if ok {
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
return
}
return
}
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(
stateKeyNIDMap map[string]types.EventStateKeyNID,
stateNeeded gomatrixserverlib.StateNeeded,
) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
}
}
for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
}
}
return keyTuples
}
// Map from event type, state key tuple to numeric event ID.
// Implemented using binary search on a sorted array.
type stateEntryMap []types.StateEntry
// lookup an entry in the event map.
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size and are controlled by us.
list := []types.StateEntry(m)
i := sort.Search(len(list), func(i int) bool {
return !list[i].StateKeyTuple.LessThan(stateKey)
})
if i < len(list) && list[i].StateKeyTuple == stateKey {
ok = true
eventNID = list[i].EventNID
}
return
}
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
type EventMap []types.Event
// lookup an entry in the event map.
func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size are controlled by us.
list := []types.Event(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].EventNID >= eventNID
})
if i < len(list) && list[i].EventNID == eventNID {
ok = true
event = &list[i]
}
return
}

View file

@ -0,0 +1,136 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package helpers
import (
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
var list []types.StateEntry
for i := int64(0); i < entries; i++ {
list = append(list, types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: types.EventTypeNID(i),
EventStateKeyNID: types.EventStateKeyNID(i),
},
EventNID: types.EventNID(i),
})
}
for i := 0; i < b.N; i++ {
entryMap := stateEntryMap(list)
for j := int64(0); j < lookups; j++ {
entryMap.lookup(types.StateKeyTuple{
EventTypeNID: types.EventTypeNID(j),
EventStateKeyNID: types.EventStateKeyNID(j),
})
}
}
}
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
benchmarkStateEntryMapLookup(100, 10, b)
}
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 100, b)
}
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(100, 100, b)
}
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 10000, b)
}
func TestStateEntryMap(t *testing.T) {
entryMap := stateEntryMap([]types.StateEntry{
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3},
})
testCases := []struct {
inputTypeNID types.EventTypeNID
inputStateKey types.EventStateKeyNID
wantOK bool
wantEventNID types.EventNID
}{
// Check that tuples that in the array are in the map.
{1, 1, true, 1},
{1, 3, true, 2},
{2, 1, true, 3},
// Check that tuples that aren't in the array aren't in the map.
{0, 0, false, 0},
{1, 2, false, 0},
{3, 1, false, 0},
}
for _, testCase := range testCases {
keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey}
gotEventNID, gotOK := entryMap.lookup(keyTuple)
if testCase.wantOK != gotOK {
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
}
if testCase.wantEventNID != gotEventNID {
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
}
}
}
func TestEventMap(t *testing.T) {
events := EventMap([]types.Event{
{EventNID: 1},
{EventNID: 2},
{EventNID: 3},
{EventNID: 5},
{EventNID: 8},
})
testCases := []struct {
inputEventNID types.EventNID
wantOK bool
wantEvent *types.Event
}{
// Check that the IDs that are in the array are in the map.
{1, true, &events[0]},
{2, true, &events[1]},
{3, true, &events[2]},
{5, true, &events[3]},
{8, true, &events[4]},
// Check that tuples that aren't in the array aren't in the map.
{0, false, nil},
{4, false, nil},
{6, false, nil},
{7, false, nil},
{9, false, nil},
}
for _, testCase := range testCases {
gotEvent, gotOK := events.Lookup(testCase.inputEventNID)
if testCase.wantOK != gotOK {
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
}
if testCase.wantEvent != gotEvent {
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
}
}
}

View file

@ -0,0 +1,326 @@
package helpers
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// TODO: temporary package which has helper functions used by both internal/perform packages.
// Move these to a more sensible place.
func UpdateToInviteMembership(
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need
// to send the event.
needsSending, err := mu.SetToInvite(*add)
if err != nil {
return nil, err
}
if needsSending {
// We notify the consumers using a special event even though we will
// notify them about the change in current state as part of the normal
// room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves.
onie := api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie,
})
}
return updates, nil
}
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, err
}
if info == nil {
return false, fmt.Errorf("unknown room %s", roomID)
}
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
if err != nil {
return false, err
}
events, err := db.Events(ctx, eventNIDs)
if err != nil {
return false, err
}
gmslEvents := make([]gomatrixserverlib.Event, len(events))
for i := range events {
gmslEvents[i] = events[i].Event
}
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
}
func IsInvitePending(
ctx context.Context, db storage.Database,
roomID, userID string,
) (bool, string, string, error) {
// Look up the room NID for the supplied room ID.
info, err := db.RoomInfo(ctx, roomID)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
}
if info == nil {
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
}
// Look up the state key NID for the supplied user ID.
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
}
targetUserNID, targetUserFound := targetUserNIDs[userID]
if !targetUserFound {
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
}
// Let's see if we have an event active for the user in the room. If
// we do then it will contain a server name that we can direct the
// send_leave to.
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
}
if len(senderUserNIDs) == 0 {
return false, "", "", nil
}
userNIDToEventID := make(map[types.EventStateKeyNID]string)
for i, nid := range senderUserNIDs {
userNIDToEventID[nid] = eventIDs[i]
}
// Look up the user ID from the NID.
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
if err != nil {
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
}
if len(senderUsers) == 0 {
return false, "", "", fmt.Errorf("no senderUsers")
}
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
if !senderUserFound {
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
}
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
}
// GetMembershipsAtState filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func GetMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomMemberNID {
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
if !joinedOnly {
return stateEvents, nil
}
// Filter the events to only keep the "join" membership events
var events []types.Event
for _, event := range stateEvents {
membership, err := event.Membership()
if err != nil {
return nil, err
}
if membership == gomatrixserverlib.Join {
events = append(events, event)
}
}
return events, nil
}
func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db, info)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
return nil, err
}
eventIDs := []string{eIDs[eventNID]}
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
// Fetch the state as it was when this event was fired
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
func LoadEvents(
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
result := make([]gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents {
result[i] = stateEvents[i].Event
}
return result, nil
}
func LoadStateEvents(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return LoadEvents(ctx, db, eventNIDs)
}
func CheckServerAllowedToSeeEvent(
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
roomState := state.NewStateResolution(db, info)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
}
// TODO: We probably want to make it so that we don't have to pull
// out all the state if possible.
stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
if err != nil {
return false, err
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
}
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func ScanEventTree(
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) {
var resultNIDs []types.EventNID
var err error
var allowed bool
var events []types.Event
var next []string
var pre string
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
// duplicate events being sent in response to /backfill requests.
initialIgnoreList := make(map[string]bool, len(visited))
for k, v := range visited {
initialIgnoreList[k] = v
}
resultNIDs = make([]types.EventNID, 0, limit)
var checkedServerInRoom bool
var isServerInRoom bool
// Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events'
// "prev_event" key.
BFSLoop:
for len(front) > 0 {
// Prevent unnecessary allocations: reset the slice only when not empty.
if len(next) > 0 {
next = make([]string, 0)
}
// Retrieve the events to process from the database.
events, err = db.EventsFromIDs(ctx, front)
if err != nil {
return resultNIDs, err
}
if !checkedServerInRoom && len(events) > 0 {
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
checkedServerInRoom = true
}
for _, ev := range events {
// Break out of the loop if the provided limit is reached.
if len(resultNIDs) == limit {
break BFSLoop
}
if !initialIgnoreList[ev.EventID()] {
// Update the list of events to retrieve.
resultNIDs = append(resultNIDs, ev.EventNID)
}
// Loop through the event's parents.
for _, pre = range ev.PrevEventIDs() {
// Only add an event to the list of next events to process if it
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
)
return resultNIDs, err
}
// If the event hasn't been seen before and the HS
// requesting to retrieve it is allowed to do so, add it to
// the list of events to retrieve.
if allowed {
next = append(next, pre)
} else {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
}
}
}
}
// Repeat the same process with the parent events we just processed.
front = next
}
return resultNIDs, err
}