mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 21:32:46 +00:00
Refactor roomserver/internal - split perform stuff out (#1380)
- New package `perform` which contains all `Perform` functions - New package `helpers` which contains helper functions used by both perform and query/input functions. - Perform invite/leave have no idea how to `WriteOutputEvents` and this is now returned from `PerformInvite` or `PerformLeave` respectively. Still to do: - RSAPI is fed into the inviter/joiner/leaver - this introduces circular logic so will need to be removed. - Put query operations in a `query` package. - Put input operations (and output) in an `input` package. - Factor out helper functions as much as possible, possibly rejigging the storage layer in the process.
This commit is contained in:
parent
02a73f29f8
commit
e473320e73
15 changed files with 820 additions and 647 deletions
244
roomserver/internal/helpers/auth.go
Normal file
244
roomserver/internal/helpers/auth.go
Normal file
|
@ -0,0 +1,244 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// CheckAuthEvents checks that the event passes authentication checks
|
||||
// Returns the numeric IDs for the auth events.
|
||||
func CheckAuthEvents(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
event gomatrixserverlib.HeaderedEvent,
|
||||
authEventIDs []string,
|
||||
) ([]types.EventNID, error) {
|
||||
// Grab the numeric IDs for the supplied auth state events from the database.
|
||||
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO: check for duplicate state keys here.
|
||||
|
||||
// Work out which of the state events we actually need.
|
||||
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
|
||||
|
||||
// Load the actual auth events from the database.
|
||||
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Check if the event is allowed.
|
||||
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Return the numeric IDs for the auth events.
|
||||
result := make([]types.EventNID, len(authStateEntries))
|
||||
for i := range authStateEntries {
|
||||
result[i] = authStateEntries[i].EventNID
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
type authEvents struct {
|
||||
stateKeyNIDMap map[string]types.EventStateKeyNID
|
||||
state stateEntryMap
|
||||
events EventMap
|
||||
}
|
||||
|
||||
// Create implements gomatrixserverlib.AuthEventProvider
|
||||
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
|
||||
}
|
||||
|
||||
// PowerLevels implements gomatrixserverlib.AuthEventProvider
|
||||
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
|
||||
}
|
||||
|
||||
// JoinRules implements gomatrixserverlib.AuthEventProvider
|
||||
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
|
||||
}
|
||||
|
||||
// Memmber implements gomatrixserverlib.AuthEventProvider
|
||||
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
|
||||
}
|
||||
|
||||
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
|
||||
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
|
||||
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||
EventTypeNID: typeNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
event, ok := ae.events.Lookup(eventNID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &event.Event
|
||||
}
|
||||
|
||||
func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
|
||||
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
|
||||
EventTypeNID: typeNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
event, ok := ae.events.Lookup(eventNID)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return &event.Event
|
||||
}
|
||||
|
||||
// loadAuthEvents loads the events needed for authentication from the supplied room state.
|
||||
func loadAuthEvents(
|
||||
ctx context.Context,
|
||||
db storage.Database,
|
||||
needed gomatrixserverlib.StateNeeded,
|
||||
state []types.StateEntry,
|
||||
) (result authEvents, err error) {
|
||||
// Look up the numeric IDs for the state keys needed for auth.
|
||||
var neededStateKeys []string
|
||||
neededStateKeys = append(neededStateKeys, needed.Member...)
|
||||
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
|
||||
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Load the events we need.
|
||||
result.state = state
|
||||
var eventNIDs []types.EventNID
|
||||
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
|
||||
for _, keyTuple := range keyTuplesNeeded {
|
||||
eventNID, ok := result.state.lookup(keyTuple)
|
||||
if ok {
|
||||
eventNIDs = append(eventNIDs, eventNID)
|
||||
}
|
||||
}
|
||||
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
|
||||
func stateKeyTuplesNeeded(
|
||||
stateKeyNIDMap map[string]types.EventStateKeyNID,
|
||||
stateNeeded gomatrixserverlib.StateNeeded,
|
||||
) []types.StateKeyTuple {
|
||||
var keyTuples []types.StateKeyTuple
|
||||
if stateNeeded.Create {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomCreateNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.PowerLevels {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomPowerLevelsNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
if stateNeeded.JoinRules {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomJoinRulesNID,
|
||||
EventStateKeyNID: types.EmptyStateKeyNID,
|
||||
})
|
||||
}
|
||||
for _, member := range stateNeeded.Member {
|
||||
stateKeyNID, ok := stateKeyNIDMap[member]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomMemberNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
for _, token := range stateNeeded.ThirdPartyInvite {
|
||||
stateKeyNID, ok := stateKeyNIDMap[token]
|
||||
if ok {
|
||||
keyTuples = append(keyTuples, types.StateKeyTuple{
|
||||
EventTypeNID: types.MRoomThirdPartyInviteNID,
|
||||
EventStateKeyNID: stateKeyNID,
|
||||
})
|
||||
}
|
||||
}
|
||||
return keyTuples
|
||||
}
|
||||
|
||||
// Map from event type, state key tuple to numeric event ID.
|
||||
// Implemented using binary search on a sorted array.
|
||||
type stateEntryMap []types.StateEntry
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
// size and are controlled by us.
|
||||
list := []types.StateEntry(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return !list[i].StateKeyTuple.LessThan(stateKey)
|
||||
})
|
||||
if i < len(list) && list[i].StateKeyTuple == stateKey {
|
||||
ok = true
|
||||
eventNID = list[i].EventNID
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Map from numeric event ID to event.
|
||||
// Implemented using binary search on a sorted array.
|
||||
type EventMap []types.Event
|
||||
|
||||
// lookup an entry in the event map.
|
||||
func (m EventMap) Lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
|
||||
// Since the list is sorted we can implement this using binary search.
|
||||
// This is faster than using a hash map.
|
||||
// We don't have to worry about pathological cases because the keys are fixed
|
||||
// size are controlled by us.
|
||||
list := []types.Event(m)
|
||||
i := sort.Search(len(list), func(i int) bool {
|
||||
return list[i].EventNID >= eventNID
|
||||
})
|
||||
if i < len(list) && list[i].EventNID == eventNID {
|
||||
ok = true
|
||||
event = &list[i]
|
||||
}
|
||||
return
|
||||
}
|
136
roomserver/internal/helpers/auth_test.go
Normal file
136
roomserver/internal/helpers/auth_test.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package helpers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
|
||||
var list []types.StateEntry
|
||||
for i := int64(0); i < entries; i++ {
|
||||
list = append(list, types.StateEntry{
|
||||
StateKeyTuple: types.StateKeyTuple{
|
||||
EventTypeNID: types.EventTypeNID(i),
|
||||
EventStateKeyNID: types.EventStateKeyNID(i),
|
||||
},
|
||||
EventNID: types.EventNID(i),
|
||||
})
|
||||
}
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
entryMap := stateEntryMap(list)
|
||||
for j := int64(0); j < lookups; j++ {
|
||||
entryMap.lookup(types.StateKeyTuple{
|
||||
EventTypeNID: types.EventTypeNID(j),
|
||||
EventStateKeyNID: types.EventStateKeyNID(j),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(100, 10, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(1000, 100, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(100, 100, b)
|
||||
}
|
||||
|
||||
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
|
||||
benchmarkStateEntryMapLookup(1000, 10000, b)
|
||||
}
|
||||
|
||||
func TestStateEntryMap(t *testing.T) {
|
||||
entryMap := stateEntryMap([]types.StateEntry{
|
||||
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
|
||||
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2},
|
||||
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3},
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputTypeNID types.EventTypeNID
|
||||
inputStateKey types.EventStateKeyNID
|
||||
wantOK bool
|
||||
wantEventNID types.EventNID
|
||||
}{
|
||||
// Check that tuples that in the array are in the map.
|
||||
{1, 1, true, 1},
|
||||
{1, 3, true, 2},
|
||||
{2, 1, true, 3},
|
||||
// Check that tuples that aren't in the array aren't in the map.
|
||||
{0, 0, false, 0},
|
||||
{1, 2, false, 0},
|
||||
{3, 1, false, 0},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey}
|
||||
gotEventNID, gotOK := entryMap.lookup(keyTuple)
|
||||
if testCase.wantOK != gotOK {
|
||||
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
|
||||
}
|
||||
if testCase.wantEventNID != gotEventNID {
|
||||
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestEventMap(t *testing.T) {
|
||||
events := EventMap([]types.Event{
|
||||
{EventNID: 1},
|
||||
{EventNID: 2},
|
||||
{EventNID: 3},
|
||||
{EventNID: 5},
|
||||
{EventNID: 8},
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
inputEventNID types.EventNID
|
||||
wantOK bool
|
||||
wantEvent *types.Event
|
||||
}{
|
||||
// Check that the IDs that are in the array are in the map.
|
||||
{1, true, &events[0]},
|
||||
{2, true, &events[1]},
|
||||
{3, true, &events[2]},
|
||||
{5, true, &events[3]},
|
||||
{8, true, &events[4]},
|
||||
// Check that tuples that aren't in the array aren't in the map.
|
||||
{0, false, nil},
|
||||
{4, false, nil},
|
||||
{6, false, nil},
|
||||
{7, false, nil},
|
||||
{9, false, nil},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
gotEvent, gotOK := events.Lookup(testCase.inputEventNID)
|
||||
if testCase.wantOK != gotOK {
|
||||
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
|
||||
}
|
||||
|
||||
if testCase.wantEvent != gotEvent {
|
||||
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
326
roomserver/internal/helpers/helpers.go
Normal file
326
roomserver/internal/helpers/helpers.go
Normal file
|
@ -0,0 +1,326 @@
|
|||
package helpers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/auth"
|
||||
"github.com/matrix-org/dendrite/roomserver/state"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// TODO: temporary package which has helper functions used by both internal/perform packages.
|
||||
// Move these to a more sensible place.
|
||||
|
||||
func UpdateToInviteMembership(
|
||||
mu *shared.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
|
||||
roomVersion gomatrixserverlib.RoomVersion,
|
||||
) ([]api.OutputEvent, error) {
|
||||
// We may have already sent the invite to the user, either because we are
|
||||
// reprocessing this event, or because the we received this invite from a
|
||||
// remote server via the federation invite API. In those cases we don't need
|
||||
// to send the event.
|
||||
needsSending, err := mu.SetToInvite(*add)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if needsSending {
|
||||
// We notify the consumers using a special event even though we will
|
||||
// notify them about the change in current state as part of the normal
|
||||
// room event stream. This ensures that the consumers only have to
|
||||
// consider a single stream of events when determining whether a user
|
||||
// is invited, rather than having to combine multiple streams themselves.
|
||||
onie := api.OutputNewInviteEvent{
|
||||
Event: add.Headered(roomVersion),
|
||||
RoomVersion: roomVersion,
|
||||
}
|
||||
updates = append(updates, api.OutputEvent{
|
||||
Type: api.OutputTypeNewInviteEvent,
|
||||
NewInviteEvent: &onie,
|
||||
})
|
||||
}
|
||||
return updates, nil
|
||||
}
|
||||
|
||||
func IsServerCurrentlyInRoom(ctx context.Context, db storage.Database, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
|
||||
info, err := db.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if info == nil {
|
||||
return false, fmt.Errorf("unknown room %s", roomID)
|
||||
}
|
||||
|
||||
eventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, info.RoomNID, true, false)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
events, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
gmslEvents := make([]gomatrixserverlib.Event, len(events))
|
||||
for i := range events {
|
||||
gmslEvents[i] = events[i].Event
|
||||
}
|
||||
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
|
||||
}
|
||||
|
||||
func IsInvitePending(
|
||||
ctx context.Context, db storage.Database,
|
||||
roomID, userID string,
|
||||
) (bool, string, string, error) {
|
||||
// Look up the room NID for the supplied room ID.
|
||||
info, err := db.RoomInfo(ctx, roomID)
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("r.DB.RoomInfo: %w", err)
|
||||
}
|
||||
if info == nil {
|
||||
return false, "", "", fmt.Errorf("cannot get RoomInfo: unknown room ID %s", roomID)
|
||||
}
|
||||
|
||||
// Look up the state key NID for the supplied user ID.
|
||||
targetUserNIDs, err := db.EventStateKeyNIDs(ctx, []string{userID})
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeyNIDs: %w", err)
|
||||
}
|
||||
targetUserNID, targetUserFound := targetUserNIDs[userID]
|
||||
if !targetUserFound {
|
||||
return false, "", "", fmt.Errorf("missing NID for user %q (%+v)", userID, targetUserNIDs)
|
||||
}
|
||||
|
||||
// Let's see if we have an event active for the user in the room. If
|
||||
// we do then it will contain a server name that we can direct the
|
||||
// send_leave to.
|
||||
senderUserNIDs, eventIDs, err := db.GetInvitesForUser(ctx, info.RoomNID, targetUserNID)
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("r.DB.GetInvitesForUser: %w", err)
|
||||
}
|
||||
if len(senderUserNIDs) == 0 {
|
||||
return false, "", "", nil
|
||||
}
|
||||
userNIDToEventID := make(map[types.EventStateKeyNID]string)
|
||||
for i, nid := range senderUserNIDs {
|
||||
userNIDToEventID[nid] = eventIDs[i]
|
||||
}
|
||||
|
||||
// Look up the user ID from the NID.
|
||||
senderUsers, err := db.EventStateKeys(ctx, senderUserNIDs)
|
||||
if err != nil {
|
||||
return false, "", "", fmt.Errorf("r.DB.EventStateKeys: %w", err)
|
||||
}
|
||||
if len(senderUsers) == 0 {
|
||||
return false, "", "", fmt.Errorf("no senderUsers")
|
||||
}
|
||||
|
||||
senderUser, senderUserFound := senderUsers[senderUserNIDs[0]]
|
||||
if !senderUserFound {
|
||||
return false, "", "", fmt.Errorf("missing user for NID %d (%+v)", senderUserNIDs[0], senderUsers)
|
||||
}
|
||||
|
||||
return true, senderUser, userNIDToEventID[senderUserNIDs[0]], nil
|
||||
}
|
||||
|
||||
// GetMembershipsAtState filters the state events to
|
||||
// only keep the "m.room.member" events with a "join" membership. These events are returned.
|
||||
// Returns an error if there was an issue fetching the events.
|
||||
func GetMembershipsAtState(
|
||||
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
|
||||
) ([]types.Event, error) {
|
||||
|
||||
var eventNIDs []types.EventNID
|
||||
for _, entry := range stateEntries {
|
||||
// Filter the events to retrieve to only keep the membership events
|
||||
if entry.EventTypeNID == types.MRoomMemberNID {
|
||||
eventNIDs = append(eventNIDs, entry.EventNID)
|
||||
}
|
||||
}
|
||||
|
||||
// Get all of the events in this state
|
||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !joinedOnly {
|
||||
return stateEvents, nil
|
||||
}
|
||||
|
||||
// Filter the events to only keep the "join" membership events
|
||||
var events []types.Event
|
||||
for _, event := range stateEvents {
|
||||
membership, err := event.Membership()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if membership == gomatrixserverlib.Join {
|
||||
events = append(events, event)
|
||||
}
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func StateBeforeEvent(ctx context.Context, db storage.Database, info types.RoomInfo, eventNID types.EventNID) ([]types.StateEntry, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
// Lookup the event NID
|
||||
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
eventIDs := []string{eIDs[eventNID]}
|
||||
|
||||
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fetch the state as it was when this event was fired
|
||||
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
|
||||
}
|
||||
|
||||
func LoadEvents(
|
||||
ctx context.Context, db storage.Database, eventNIDs []types.EventNID,
|
||||
) ([]gomatrixserverlib.Event, error) {
|
||||
stateEvents, err := db.Events(ctx, eventNIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
result := make([]gomatrixserverlib.Event, len(stateEvents))
|
||||
for i := range stateEvents {
|
||||
result[i] = stateEvents[i].Event
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func LoadStateEvents(
|
||||
ctx context.Context, db storage.Database, stateEntries []types.StateEntry,
|
||||
) ([]gomatrixserverlib.Event, error) {
|
||||
eventNIDs := make([]types.EventNID, len(stateEntries))
|
||||
for i := range stateEntries {
|
||||
eventNIDs[i] = stateEntries[i].EventNID
|
||||
}
|
||||
return LoadEvents(ctx, db, eventNIDs)
|
||||
}
|
||||
|
||||
func CheckServerAllowedToSeeEvent(
|
||||
ctx context.Context, db storage.Database, info types.RoomInfo, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
|
||||
) (bool, error) {
|
||||
roomState := state.NewStateResolution(db, info)
|
||||
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// TODO: We probably want to make it so that we don't have to pull
|
||||
// out all the state if possible.
|
||||
stateAtEvent, err := LoadStateEvents(ctx, db, stateEntries)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
|
||||
}
|
||||
|
||||
// TODO: Remove this when we have tests to assert correctness of this function
|
||||
// nolint:gocyclo
|
||||
func ScanEventTree(
|
||||
ctx context.Context, db storage.Database, info types.RoomInfo, front []string, visited map[string]bool, limit int,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) ([]types.EventNID, error) {
|
||||
var resultNIDs []types.EventNID
|
||||
var err error
|
||||
var allowed bool
|
||||
var events []types.Event
|
||||
var next []string
|
||||
var pre string
|
||||
|
||||
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
|
||||
// Currently, callers like PerformBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
|
||||
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
|
||||
// duplicate events being sent in response to /backfill requests.
|
||||
initialIgnoreList := make(map[string]bool, len(visited))
|
||||
for k, v := range visited {
|
||||
initialIgnoreList[k] = v
|
||||
}
|
||||
|
||||
resultNIDs = make([]types.EventNID, 0, limit)
|
||||
|
||||
var checkedServerInRoom bool
|
||||
var isServerInRoom bool
|
||||
|
||||
// Loop through the event IDs to retrieve the requested events and go
|
||||
// through the whole tree (up to the provided limit) using the events'
|
||||
// "prev_event" key.
|
||||
BFSLoop:
|
||||
for len(front) > 0 {
|
||||
// Prevent unnecessary allocations: reset the slice only when not empty.
|
||||
if len(next) > 0 {
|
||||
next = make([]string, 0)
|
||||
}
|
||||
// Retrieve the events to process from the database.
|
||||
events, err = db.EventsFromIDs(ctx, front)
|
||||
if err != nil {
|
||||
return resultNIDs, err
|
||||
}
|
||||
|
||||
if !checkedServerInRoom && len(events) > 0 {
|
||||
// It's nasty that we have to extract the room ID from an event, but many federation requests
|
||||
// only talk in event IDs, no room IDs at all (!!!)
|
||||
ev := events[0]
|
||||
isServerInRoom, err = IsServerCurrentlyInRoom(ctx, db, serverName, ev.RoomID())
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
|
||||
}
|
||||
checkedServerInRoom = true
|
||||
}
|
||||
|
||||
for _, ev := range events {
|
||||
// Break out of the loop if the provided limit is reached.
|
||||
if len(resultNIDs) == limit {
|
||||
break BFSLoop
|
||||
}
|
||||
|
||||
if !initialIgnoreList[ev.EventID()] {
|
||||
// Update the list of events to retrieve.
|
||||
resultNIDs = append(resultNIDs, ev.EventNID)
|
||||
}
|
||||
// Loop through the event's parents.
|
||||
for _, pre = range ev.PrevEventIDs() {
|
||||
// Only add an event to the list of next events to process if it
|
||||
// hasn't been seen before.
|
||||
if !visited[pre] {
|
||||
visited[pre] = true
|
||||
allowed, err = CheckServerAllowedToSeeEvent(ctx, db, info, pre, serverName, isServerInRoom)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
|
||||
"Error checking if allowed to see event",
|
||||
)
|
||||
return resultNIDs, err
|
||||
}
|
||||
|
||||
// If the event hasn't been seen before and the HS
|
||||
// requesting to retrieve it is allowed to do so, add it to
|
||||
// the list of events to retrieve.
|
||||
if allowed {
|
||||
next = append(next, pre)
|
||||
} else {
|
||||
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Repeat the same process with the parent events we just processed.
|
||||
front = next
|
||||
}
|
||||
|
||||
return resultNIDs, err
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue