Consolidation of roomserver APIs (#994)

* Consolidation of roomserver APIs

* Comment out alias tests for now, they are broken

* Wire AS API into roomserver again

* Roomserver didn't take asAPI param before so return to that

* Prevent roomserver asking AS API for alias info

* Rename some files

* Remove alias_test, incoherent tests and unwanted appservice integration

* Remove FS API inject on syncapi component
This commit is contained in:
Neil Alexander 2020-05-01 10:48:17 +01:00 committed by GitHub
parent ebbfc12592
commit e15f6676ac
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
72 changed files with 894 additions and 1170 deletions

View file

@ -0,0 +1,271 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"encoding/json"
"errors"
"time"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
)
// RoomserverInternalAPIDatabase has the storage APIs needed to implement the alias API.
type RoomserverInternalAPIDatabase interface {
// Save a given room alias with the room ID it refers to.
// Returns an error if there was a problem talking to the database.
SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error
// Look up the room ID a given alias refers to.
// Returns an error if there was a problem talking to the database.
GetRoomIDForAlias(ctx context.Context, alias string) (string, error)
// Look up all aliases referring to a given room ID.
// Returns an error if there was a problem talking to the database.
GetAliasesForRoomID(ctx context.Context, roomID string) ([]string, error)
// Get the user ID of the creator of an alias.
// Returns an error if there was a problem talking to the database.
GetCreatorIDForAlias(ctx context.Context, alias string) (string, error)
// Remove a given room alias.
// Returns an error if there was a problem talking to the database.
RemoveRoomAlias(ctx context.Context, alias string) error
// Look up the room version for a given room.
GetRoomVersionForRoom(
ctx context.Context, roomID string,
) (gomatrixserverlib.RoomVersion, error)
}
// SetRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) SetRoomAlias(
ctx context.Context,
request *api.SetRoomAliasRequest,
response *api.SetRoomAliasResponse,
) error {
// Check if the alias isn't already referring to a room
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
if len(roomID) > 0 {
// If the alias already exists, stop the process
response.AliasExists = true
return nil
}
response.AliasExists = false
// Save the new alias
if err := r.DB.SetRoomAlias(ctx, request.Alias, request.RoomID, request.UserID); err != nil {
return err
}
// Send a m.room.aliases event with the updated list of aliases for this room
// At this point we've already committed the alias to the database so we
// shouldn't cancel this request.
// TODO: Ensure that we send unsent events when if server restarts.
return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, request.RoomID)
}
// GetRoomIDForAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) GetRoomIDForAlias(
ctx context.Context,
request *api.GetRoomIDForAliasRequest,
response *api.GetRoomIDForAliasResponse,
) error {
// Look up the room ID in the database
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
/*
TODO: Why is this here? It creates an unnecessary dependency
from the roomserver to the appservice component, which should be
altogether optional.
if roomID == "" {
// No room found locally, try our application services by making a call to
// the appservice component
aliasReq := appserviceAPI.RoomAliasExistsRequest{Alias: request.Alias}
var aliasResp appserviceAPI.RoomAliasExistsResponse
if err = r.AppserviceAPI.RoomAliasExists(ctx, &aliasReq, &aliasResp); err != nil {
return err
}
if aliasResp.AliasExists {
roomID, err = r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
}
}
*/
response.RoomID = roomID
return nil
}
// GetAliasesForRoomID implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) GetAliasesForRoomID(
ctx context.Context,
request *api.GetAliasesForRoomIDRequest,
response *api.GetAliasesForRoomIDResponse,
) error {
// Look up the aliases in the database for the given RoomID
aliases, err := r.DB.GetAliasesForRoomID(ctx, request.RoomID)
if err != nil {
return err
}
response.Aliases = aliases
return nil
}
// GetCreatorIDForAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) GetCreatorIDForAlias(
ctx context.Context,
request *api.GetCreatorIDForAliasRequest,
response *api.GetCreatorIDForAliasResponse,
) error {
// Look up the aliases in the database for the given RoomID
creatorID, err := r.DB.GetCreatorIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
response.UserID = creatorID
return nil
}
// RemoveRoomAlias implements alias.RoomserverInternalAPI
func (r *RoomserverInternalAPI) RemoveRoomAlias(
ctx context.Context,
request *api.RemoveRoomAliasRequest,
response *api.RemoveRoomAliasResponse,
) error {
// Look up the room ID in the database
roomID, err := r.DB.GetRoomIDForAlias(ctx, request.Alias)
if err != nil {
return err
}
// Remove the dalias from the database
if err := r.DB.RemoveRoomAlias(ctx, request.Alias); err != nil {
return err
}
// Send an updated m.room.aliases event
// At this point we've already committed the alias to the database so we
// shouldn't cancel this request.
// TODO: Ensure that we send unsent events when if server restarts.
return r.sendUpdatedAliasesEvent(context.TODO(), request.UserID, roomID)
}
type roomAliasesContent struct {
Aliases []string `json:"aliases"`
}
// Build the updated m.room.aliases event to send to the room after addition or
// removal of an alias
func (r *RoomserverInternalAPI) sendUpdatedAliasesEvent(
ctx context.Context, userID string, roomID string,
) error {
serverName := string(r.Cfg.Matrix.ServerName)
builder := gomatrixserverlib.EventBuilder{
Sender: userID,
RoomID: roomID,
Type: "m.room.aliases",
StateKey: &serverName,
}
// Retrieve the updated list of aliases, marhal it and set it as the
// event's content
aliases, err := r.DB.GetAliasesForRoomID(ctx, roomID)
if err != nil {
return err
}
content := roomAliasesContent{Aliases: aliases}
rawContent, err := json.Marshal(content)
if err != nil {
return err
}
err = builder.SetContent(json.RawMessage(rawContent))
if err != nil {
return err
}
// Get needed state events and depth
eventsNeeded, err := gomatrixserverlib.StateNeededForEventBuilder(&builder)
if err != nil {
return err
}
if len(eventsNeeded.Tuples()) == 0 {
return errors.New("expecting state tuples for event builder, got none")
}
req := api.QueryLatestEventsAndStateRequest{
RoomID: roomID,
StateToFetch: eventsNeeded.Tuples(),
}
var res api.QueryLatestEventsAndStateResponse
if err = r.QueryLatestEventsAndState(ctx, &req, &res); err != nil {
return err
}
builder.Depth = res.Depth
builder.PrevEvents = res.LatestEvents
// Add auth events
authEvents := gomatrixserverlib.NewAuthEvents(nil)
for i := range res.StateEvents {
err = authEvents.AddEvent(&res.StateEvents[i].Event)
if err != nil {
return err
}
}
refs, err := eventsNeeded.AuthEventReferences(&authEvents)
if err != nil {
return err
}
builder.AuthEvents = refs
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, roomID)
if err != nil {
return err
}
// Build the event
now := time.Now()
event, err := builder.Build(
now, r.Cfg.Matrix.ServerName, r.Cfg.Matrix.KeyID,
r.Cfg.Matrix.PrivateKey, roomVersion,
)
if err != nil {
return err
}
// Create the request
ire := api.InputRoomEvent{
Kind: api.KindNew,
Event: event.Headered(roomVersion),
AuthEventIDs: event.AuthEventIDs(),
SendAsServer: serverName,
}
inputReq := api.InputRoomEventsRequest{
InputRoomEvents: []api.InputRoomEvent{ire},
}
var inputRes api.InputRoomEventsResponse
// Send the request
return r.InputRoomEvents(ctx, &inputReq, &inputRes)
}

287
roomserver/internal/api.go Normal file
View file

@ -0,0 +1,287 @@
package internal
import (
"encoding/json"
"net/http"
"sync"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/caching"
"github.com/matrix-org/dendrite/common/config"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// RoomserverInternalAPI is an implementation of api.RoomserverInternalAPI
type RoomserverInternalAPI struct {
DB storage.Database
Cfg *config.Dendrite
Producer sarama.SyncProducer
ImmutableCache caching.ImmutableCache
ServerName gomatrixserverlib.ServerName
KeyRing gomatrixserverlib.JSONVerifier
FedClient *gomatrixserverlib.FederationClient
OutputRoomEventTopic string // Kafka topic for new output room events
mutex sync.Mutex // Protects calls to processRoomEvent
fsAPI fsAPI.FederationSenderInternalAPI
}
// SetupHTTP adds the RoomserverInternalAPI handlers to the http.ServeMux.
// nolint: gocyclo
func (r *RoomserverInternalAPI) SetupHTTP(servMux *http.ServeMux) {
servMux.Handle(api.RoomserverInputRoomEventsPath,
common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := r.InputRoomEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryLatestEventsAndStatePath,
common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryLatestEventsAndState(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryStateAfterEventsPath,
common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAfterEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryEventsByIDPath,
common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryEventsByID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryMembershipForUserPath,
common.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipForUserRequest
var response api.QueryMembershipForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryMembershipsForRoomPath,
common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMembershipsForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryInvitesForUserPath,
common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryInvitesForUserRequest
var response api.QueryInvitesForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryInvitesForUser(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryServerAllowedToSeeEventPath,
common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryServerAllowedToSeeEvent(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryMissingEventsPath,
common.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryMissingEventsRequest
var response api.QueryMissingEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryMissingEvents(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryStateAndAuthChainPath,
common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryStateAndAuthChain(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryBackfillPath,
common.MakeInternalAPI("QueryBackfill", func(req *http.Request) util.JSONResponse {
var request api.QueryBackfillRequest
var response api.QueryBackfillResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryBackfill(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryRoomVersionCapabilitiesPath,
common.MakeInternalAPI("QueryRoomVersionCapabilities", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionCapabilitiesRequest
var response api.QueryRoomVersionCapabilitiesResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionCapabilities(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverQueryRoomVersionForRoomPath,
common.MakeInternalAPI("QueryRoomVersionForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryRoomVersionForRoomRequest
var response api.QueryRoomVersionForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.QueryRoomVersionForRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverSetRoomAliasPath,
common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.SetRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverGetRoomIDForAliasPath,
common.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
var request api.GetRoomIDForAliasRequest
var response api.GetRoomIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetRoomIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverGetCreatorIDForAliasPath,
common.MakeInternalAPI("GetCreatorIDForAlias", func(req *http.Request) util.JSONResponse {
var request api.GetCreatorIDForAliasRequest
var response api.GetCreatorIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetCreatorIDForAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverGetAliasesForRoomIDPath,
common.MakeInternalAPI("getAliasesForRoomID", func(req *http.Request) util.JSONResponse {
var request api.GetAliasesForRoomIDRequest
var response api.GetAliasesForRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.GetAliasesForRoomID(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
servMux.Handle(
api.RoomserverRemoveRoomAliasPath,
common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := r.RemoveRoomAlias(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -0,0 +1,72 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package input contains the code processes new room events
package internal
import (
"context"
"encoding/json"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/roomserver/api"
fsAPI "github.com/matrix-org/dendrite/federationsender/api"
)
// SetFederationSenderInputAPI passes in a federation sender input API reference
// so that we can avoid the chicken-and-egg problem of both the roomserver input API
// and the federation sender input API being interdependent.
func (r *RoomserverInternalAPI) SetFederationSenderAPI(fsAPI fsAPI.FederationSenderInternalAPI) {
r.fsAPI = fsAPI
}
// WriteOutputEvents implements OutputRoomEventWriter
func (r *RoomserverInternalAPI) WriteOutputEvents(roomID string, updates []api.OutputEvent) error {
messages := make([]*sarama.ProducerMessage, len(updates))
for i := range updates {
value, err := json.Marshal(updates[i])
if err != nil {
return err
}
messages[i] = &sarama.ProducerMessage{
Topic: r.OutputRoomEventTopic,
Key: sarama.StringEncoder(roomID),
Value: sarama.ByteEncoder(value),
}
}
return r.Producer.SendMessages(messages)
}
// InputRoomEvents implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) (err error) {
// We lock as processRoomEvent can only be called once at a time
r.mutex.Lock()
defer r.mutex.Unlock()
for i := range request.InputRoomEvents {
if response.EventID, err = processRoomEvent(ctx, r.DB, r, request.InputRoomEvents[i]); err != nil {
return err
}
}
for i := range request.InputInviteEvents {
if err = processInviteEvent(ctx, r.DB, r, request.InputInviteEvents[i]); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,244 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"sort"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// checkAuthEvents checks that the event passes authentication checks
// Returns the numeric IDs for the auth events.
func checkAuthEvents(
ctx context.Context,
db storage.Database,
event gomatrixserverlib.HeaderedEvent,
authEventIDs []string,
) ([]types.EventNID, error) {
// Grab the numeric IDs for the supplied auth state events from the database.
authStateEntries, err := db.StateEntriesForEventIDs(ctx, authEventIDs)
if err != nil {
return nil, err
}
// TODO: check for duplicate state keys here.
// Work out which of the state events we actually need.
stateNeeded := gomatrixserverlib.StateNeededForAuth([]gomatrixserverlib.Event{event.Unwrap()})
// Load the actual auth events from the database.
authEvents, err := loadAuthEvents(ctx, db, stateNeeded, authStateEntries)
if err != nil {
return nil, err
}
// Check if the event is allowed.
if err = gomatrixserverlib.Allowed(event.Event, &authEvents); err != nil {
return nil, err
}
// Return the numeric IDs for the auth events.
result := make([]types.EventNID, len(authStateEntries))
for i := range authStateEntries {
result[i] = authStateEntries[i].EventNID
}
return result, nil
}
type authEvents struct {
stateKeyNIDMap map[string]types.EventStateKeyNID
state stateEntryMap
events eventMap
}
// Create implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Create() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomCreateNID), nil
}
// PowerLevels implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) PowerLevels() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomPowerLevelsNID), nil
}
// JoinRules implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) JoinRules() (*gomatrixserverlib.Event, error) {
return ae.lookupEventWithEmptyStateKey(types.MRoomJoinRulesNID), nil
}
// Memmber implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) Member(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomMemberNID, stateKey), nil
}
// ThirdPartyInvite implements gomatrixserverlib.AuthEventProvider
func (ae *authEvents) ThirdPartyInvite(stateKey string) (*gomatrixserverlib.Event, error) {
return ae.lookupEvent(types.MRoomThirdPartyInviteNID, stateKey), nil
}
func (ae *authEvents) lookupEventWithEmptyStateKey(typeNID types.EventTypeNID) *gomatrixserverlib.Event {
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
if !ok {
return nil
}
event, ok := ae.events.lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
func (ae *authEvents) lookupEvent(typeNID types.EventTypeNID, stateKey string) *gomatrixserverlib.Event {
stateKeyNID, ok := ae.stateKeyNIDMap[stateKey]
if !ok {
return nil
}
eventNID, ok := ae.state.lookup(types.StateKeyTuple{
EventTypeNID: typeNID,
EventStateKeyNID: stateKeyNID,
})
if !ok {
return nil
}
event, ok := ae.events.lookup(eventNID)
if !ok {
return nil
}
return &event.Event
}
// loadAuthEvents loads the events needed for authentication from the supplied room state.
func loadAuthEvents(
ctx context.Context,
db storage.Database,
needed gomatrixserverlib.StateNeeded,
state []types.StateEntry,
) (result authEvents, err error) {
// Look up the numeric IDs for the state keys needed for auth.
var neededStateKeys []string
neededStateKeys = append(neededStateKeys, needed.Member...)
neededStateKeys = append(neededStateKeys, needed.ThirdPartyInvite...)
if result.stateKeyNIDMap, err = db.EventStateKeyNIDs(ctx, neededStateKeys); err != nil {
return
}
// Load the events we need.
result.state = state
var eventNIDs []types.EventNID
keyTuplesNeeded := stateKeyTuplesNeeded(result.stateKeyNIDMap, needed)
for _, keyTuple := range keyTuplesNeeded {
eventNID, ok := result.state.lookup(keyTuple)
if ok {
eventNIDs = append(eventNIDs, eventNID)
}
}
if result.events, err = db.Events(ctx, eventNIDs); err != nil {
return
}
return
}
// stateKeyTuplesNeeded works out which numeric state key tuples we need to authenticate some events.
func stateKeyTuplesNeeded(
stateKeyNIDMap map[string]types.EventStateKeyNID,
stateNeeded gomatrixserverlib.StateNeeded,
) []types.StateKeyTuple {
var keyTuples []types.StateKeyTuple
if stateNeeded.Create {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomCreateNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.PowerLevels {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomPowerLevelsNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
if stateNeeded.JoinRules {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomJoinRulesNID,
EventStateKeyNID: types.EmptyStateKeyNID,
})
}
for _, member := range stateNeeded.Member {
stateKeyNID, ok := stateKeyNIDMap[member]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomMemberNID,
EventStateKeyNID: stateKeyNID,
})
}
}
for _, token := range stateNeeded.ThirdPartyInvite {
stateKeyNID, ok := stateKeyNIDMap[token]
if ok {
keyTuples = append(keyTuples, types.StateKeyTuple{
EventTypeNID: types.MRoomThirdPartyInviteNID,
EventStateKeyNID: stateKeyNID,
})
}
}
return keyTuples
}
// Map from event type, state key tuple to numeric event ID.
// Implemented using binary search on a sorted array.
type stateEntryMap []types.StateEntry
// lookup an entry in the event map.
func (m stateEntryMap) lookup(stateKey types.StateKeyTuple) (eventNID types.EventNID, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size and are controlled by us.
list := []types.StateEntry(m)
i := sort.Search(len(list), func(i int) bool {
return !list[i].StateKeyTuple.LessThan(stateKey)
})
if i < len(list) && list[i].StateKeyTuple == stateKey {
ok = true
eventNID = list[i].EventNID
}
return
}
// Map from numeric event ID to event.
// Implemented using binary search on a sorted array.
type eventMap []types.Event
// lookup an entry in the event map.
func (m eventMap) lookup(eventNID types.EventNID) (event *types.Event, ok bool) {
// Since the list is sorted we can implement this using binary search.
// This is faster than using a hash map.
// We don't have to worry about pathological cases because the keys are fixed
// size are controlled by us.
list := []types.Event(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].EventNID >= eventNID
})
if i < len(list) && list[i].EventNID == eventNID {
ok = true
event = &list[i]
}
return
}

View file

@ -0,0 +1,136 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func benchmarkStateEntryMapLookup(entries, lookups int64, b *testing.B) {
var list []types.StateEntry
for i := int64(0); i < entries; i++ {
list = append(list, types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
EventTypeNID: types.EventTypeNID(i),
EventStateKeyNID: types.EventStateKeyNID(i),
},
EventNID: types.EventNID(i),
})
}
for i := 0; i < b.N; i++ {
entryMap := stateEntryMap(list)
for j := int64(0); j < lookups; j++ {
entryMap.lookup(types.StateKeyTuple{
EventTypeNID: types.EventTypeNID(j),
EventStateKeyNID: types.EventStateKeyNID(j),
})
}
}
}
func BenchmarkStateEntryMap100Lookup10(b *testing.B) {
benchmarkStateEntryMapLookup(100, 10, b)
}
func BenchmarkStateEntryMap1000Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 100, b)
}
func BenchmarkStateEntryMap100Lookup100(b *testing.B) {
benchmarkStateEntryMapLookup(100, 100, b)
}
func BenchmarkStateEntryMap1000Lookup10000(b *testing.B) {
benchmarkStateEntryMapLookup(1000, 10000, b)
}
func TestStateEntryMap(t *testing.T) {
entryMap := stateEntryMap([]types.StateEntry{
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 1}, EventNID: 1},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 1, EventStateKeyNID: 3}, EventNID: 2},
{StateKeyTuple: types.StateKeyTuple{EventTypeNID: 2, EventStateKeyNID: 1}, EventNID: 3},
})
testCases := []struct {
inputTypeNID types.EventTypeNID
inputStateKey types.EventStateKeyNID
wantOK bool
wantEventNID types.EventNID
}{
// Check that tuples that in the array are in the map.
{1, 1, true, 1},
{1, 3, true, 2},
{2, 1, true, 3},
// Check that tuples that aren't in the array aren't in the map.
{0, 0, false, 0},
{1, 2, false, 0},
{3, 1, false, 0},
}
for _, testCase := range testCases {
keyTuple := types.StateKeyTuple{EventTypeNID: testCase.inputTypeNID, EventStateKeyNID: testCase.inputStateKey}
gotEventNID, gotOK := entryMap.lookup(keyTuple)
if testCase.wantOK != gotOK {
t.Fatalf("stateEntryMap lookup(%v): want ok to be %v, got %v", keyTuple, testCase.wantOK, gotOK)
}
if testCase.wantEventNID != gotEventNID {
t.Fatalf("stateEntryMap lookup(%v): want eventNID to be %v, got %v", keyTuple, testCase.wantEventNID, gotEventNID)
}
}
}
func TestEventMap(t *testing.T) {
events := eventMap([]types.Event{
{EventNID: 1},
{EventNID: 2},
{EventNID: 3},
{EventNID: 5},
{EventNID: 8},
})
testCases := []struct {
inputEventNID types.EventNID
wantOK bool
wantEvent *types.Event
}{
// Check that the IDs that are in the array are in the map.
{1, true, &events[0]},
{2, true, &events[1]},
{3, true, &events[2]},
{5, true, &events[3]},
{8, true, &events[4]},
// Check that tuples that aren't in the array aren't in the map.
{0, false, nil},
{4, false, nil},
{6, false, nil},
{7, false, nil},
{9, false, nil},
}
for _, testCase := range testCases {
gotEvent, gotOK := events.lookup(testCase.inputEventNID)
if testCase.wantOK != gotOK {
t.Fatalf("eventMap lookup(%v): want ok to be %v, got %v", testCase.inputEventNID, testCase.wantOK, gotOK)
}
if testCase.wantEvent != gotEvent {
t.Fatalf("eventMap lookup(%v): want event to be %v, got %v", testCase.inputEventNID, testCase.wantEvent, gotEvent)
}
}
}

View file

@ -0,0 +1,274 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
)
// OutputRoomEventWriter has the APIs needed to write an event to the output logs.
type OutputRoomEventWriter interface {
// Write a list of events for a room
WriteOutputEvents(roomID string, updates []api.OutputEvent) error
}
// processRoomEvent can only be called once at a time
//
// TODO(#375): This should be rewritten to allow concurrent calls. The
// difficulty is in ensuring that we correctly annotate events with the correct
// state deltas when sending to kafka streams
func processRoomEvent(
ctx context.Context,
db storage.Database,
ow OutputRoomEventWriter,
input api.InputRoomEvent,
) (eventID string, err error) {
// Parse and validate the event JSON
headered := input.Event
event := headered.Unwrap()
// Check that the event passes authentication checks and work out the numeric IDs for the auth events.
authEventNIDs, err := checkAuthEvents(ctx, db, headered, input.AuthEventIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", event.EventID()).Error("processRoomEvent.checkAuthEvents failed for event")
return
}
if input.TransactionID != nil {
tdID := input.TransactionID
eventID, err = db.GetTransactionEventID(
ctx, tdID.TransactionID, tdID.SessionID, event.Sender(),
)
// On error OR event with the transaction already processed/processesing
if err != nil || eventID != "" {
return
}
}
// Store the event
roomNID, stateAtEvent, err := db.StoreEvent(ctx, event, input.TransactionID, authEventNIDs)
if err != nil {
return
}
if input.Kind == api.KindOutlier {
// For outliers we can stop after we've stored the event itself as it
// doesn't have any associated state to store and we don't need to
// notify anyone about it.
logrus.WithField("event_id", event.EventID()).WithField("type", event.Type()).WithField("room", event.RoomID()).Info("Stored outlier")
return event.EventID(), nil
}
if stateAtEvent.BeforeStateSnapshotNID == 0 {
// We haven't calculated a state for this event yet.
// Lets calculate one.
err = calculateAndSetState(ctx, db, input, roomNID, &stateAtEvent, event)
if err != nil {
return
}
}
// Update the extremities of the event graph for the room
return event.EventID(), updateLatestEvents(
ctx, db, ow, roomNID, stateAtEvent, event, input.SendAsServer, input.TransactionID,
)
}
func calculateAndSetState(
ctx context.Context,
db storage.Database,
input api.InputRoomEvent,
roomNID types.RoomNID,
stateAtEvent *types.StateAtEvent,
event gomatrixserverlib.Event,
) error {
var err error
roomState := state.NewStateResolution(db)
if input.HasState {
// We've been told what the state at the event is so we don't need to calculate it.
// Check that those state events are in the database and store the state.
var entries []types.StateEntry
if entries, err = db.StateEntriesForEventIDs(ctx, input.StateEventIDs); err != nil {
return err
}
if stateAtEvent.BeforeStateSnapshotNID, err = db.AddState(ctx, roomNID, nil, entries); err != nil {
return err
}
} else {
// We haven't been told what the state at the event is so we need to calculate it from the prev_events
if stateAtEvent.BeforeStateSnapshotNID, err = roomState.CalculateAndStoreStateBeforeEvent(ctx, event, roomNID); err != nil {
return err
}
}
return db.SetState(ctx, stateAtEvent.EventNID, stateAtEvent.BeforeStateSnapshotNID)
}
func processInviteEvent(
ctx context.Context,
db storage.Database,
ow OutputRoomEventWriter,
input api.InputInviteEvent,
) (err error) {
if input.Event.StateKey() == nil {
return fmt.Errorf("invite must be a state event")
}
roomID := input.Event.RoomID()
targetUserID := *input.Event.StateKey()
log.WithFields(log.Fields{
"event_id": input.Event.EventID(),
"room_id": roomID,
"room_version": input.RoomVersion,
"target_user_id": targetUserID,
}).Info("processing invite event")
updater, err := db.MembershipUpdater(ctx, roomID, targetUserID, input.RoomVersion)
if err != nil {
return err
}
succeeded := false
defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
if updater.IsJoin() {
// If the user is joined to the room then that takes precedence over this
// invite event. It makes little sense to move a user that is already
// joined to the room into the invite state.
// This could plausibly happen if an invite request raced with a join
// request for a user. For example if a user was invited to a public
// room and they joined the room at the same time as the invite was sent.
// The other way this could plausibly happen is if an invite raced with
// a kick. For example if a user was kicked from a room in error and in
// response someone else in the room re-invited them then it is possible
// for the invite request to race with the leave event so that the
// target receives invite before it learns that it has been kicked.
// There are a few ways this could be plausibly handled in the roomserver.
// 1) Store the invite, but mark it as retired. That will result in the
// permanent rejection of that invite event. So even if the target
// user leaves the room and the invite is retransmitted it will be
// ignored. However a new invite with a new event ID would still be
// accepted.
// 2) Silently discard the invite event. This means that if the event
// was retransmitted at a later date after the target user had left
// the room we would accept the invite. However since we hadn't told
// the sending server that the invite had been discarded it would
// have no reason to attempt to retry.
// 3) Signal the sending server that the user is already joined to the
// room.
// For now we will implement option 2. Since in the abesence of a retry
// mechanism it will be equivalent to option 1, and we don't have a
// signalling mechanism to implement option 3.
return nil
}
event := input.Event.Unwrap()
if len(input.InviteRoomState) > 0 {
// If we were supplied with some invite room state already (which is
// most likely to be if the event came in over federation) then use
// that.
if err = event.SetUnsignedField("invite_room_state", input.InviteRoomState); err != nil {
return err
}
} else {
// There's no invite room state, so let's have a go at building it
// up from local data (which is most likely to be if the event came
// from the CS API). If we know about the room then we can insert
// the invite room state, if we don't then we just fail quietly.
if irs, ierr := buildInviteStrippedState(ctx, db, input); ierr == nil {
if err = event.SetUnsignedField("invite_room_state", irs); err != nil {
return err
}
}
}
outputUpdates, err := updateToInviteMembership(updater, &event, nil, input.Event.RoomVersion)
if err != nil {
return err
}
if err = ow.WriteOutputEvents(roomID, outputUpdates); err != nil {
return err
}
succeeded = true
return nil
}
func buildInviteStrippedState(
ctx context.Context,
db storage.Database,
input api.InputInviteEvent,
) ([]gomatrixserverlib.InviteV2StrippedState, error) {
roomNID, err := db.RoomNID(ctx, input.Event.RoomID())
if err != nil || roomNID == 0 {
return nil, fmt.Errorf("room %q unknown", input.Event.RoomID())
}
stateWanted := []gomatrixserverlib.StateKeyTuple{}
for _, t := range []string{
gomatrixserverlib.MRoomName, gomatrixserverlib.MRoomCanonicalAlias,
gomatrixserverlib.MRoomAliases, gomatrixserverlib.MRoomJoinRules,
} {
stateWanted = append(stateWanted, gomatrixserverlib.StateKeyTuple{
EventType: t,
StateKey: "",
})
}
_, currentStateSnapshotNID, _, err := db.LatestEventIDs(ctx, roomNID)
if err != nil {
return nil, err
}
roomState := state.NewStateResolution(db)
stateEntries, err := roomState.LoadStateAtSnapshotForStringTuples(
ctx, currentStateSnapshotNID, stateWanted,
)
if err != nil {
return nil, err
}
stateNIDs := []types.EventNID{}
for _, stateNID := range stateEntries {
stateNIDs = append(stateNIDs, stateNID.EventNID)
}
stateEvents, err := db.Events(ctx, stateNIDs)
if err != nil {
return nil, err
}
inviteState := []gomatrixserverlib.InviteV2StrippedState{
gomatrixserverlib.NewInviteV2StrippedState(&input.Event.Event),
}
for _, event := range stateEvents {
inviteState = append(inviteState, gomatrixserverlib.NewInviteV2StrippedState(&event.Event))
}
return inviteState, nil
}

View file

@ -0,0 +1,307 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"bytes"
"context"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// updateLatestEvents updates the list of latest events for this room in the database and writes the
// event to the output log.
// The latest events are the events that aren't referenced by another event in the database:
//
// Time goes down the page. 1 is the m.room.create event (root).
//
// 1 After storing 1 the latest events are {1}
// | After storing 2 the latest events are {2}
// 2 After storing 3 the latest events are {3}
// / \ After storing 4 the latest events are {3,4}
// 3 4 After storing 5 the latest events are {5,4}
// | | After storing 6 the latest events are {5,6}
// 5 6 <--- latest After storing 7 the latest events are {6,7}
// |
// 7 <----- latest
//
// Can only be called once at a time
func updateLatestEvents(
ctx context.Context,
db storage.Database,
ow OutputRoomEventWriter,
roomNID types.RoomNID,
stateAtEvent types.StateAtEvent,
event gomatrixserverlib.Event,
sendAsServer string,
transactionID *api.TransactionID,
) (err error) {
updater, err := db.GetLatestEventsForUpdate(ctx, roomNID)
if err != nil {
return
}
succeeded := false
defer func() {
txerr := common.EndTransaction(updater, &succeeded)
if err == nil && txerr != nil {
err = txerr
}
}()
u := latestEventsUpdater{
ctx: ctx, db: db, updater: updater, ow: ow, roomNID: roomNID,
stateAtEvent: stateAtEvent, event: event, sendAsServer: sendAsServer,
transactionID: transactionID,
}
if err = u.doUpdateLatestEvents(); err != nil {
return err
}
succeeded = true
return
}
// latestEventsUpdater tracks the state used to update the latest events in the
// room. It mostly just ferries state between the various function calls.
// The state could be passed using function arguments, but it becomes impractical
// when there are so many variables to pass around.
type latestEventsUpdater struct {
ctx context.Context
db storage.Database
updater types.RoomRecentEventsUpdater
ow OutputRoomEventWriter
roomNID types.RoomNID
stateAtEvent types.StateAtEvent
event gomatrixserverlib.Event
transactionID *api.TransactionID
// Which server to send this event as.
sendAsServer string
// The eventID of the event that was processed before this one.
lastEventIDSent string
// The latest events in the room after processing this event.
latest []types.StateAtEventAndReference
// The state entries removed from and added to the current state of the
// room as a result of processing this event. They are sorted lists.
removed []types.StateEntry
added []types.StateEntry
// The state entries that are removed and added to recover the state before
// the event being processed. They are sorted lists.
stateBeforeEventRemoves []types.StateEntry
stateBeforeEventAdds []types.StateEntry
// The snapshots of current state before and after processing this event
oldStateNID types.StateSnapshotNID
newStateNID types.StateSnapshotNID
}
func (u *latestEventsUpdater) doUpdateLatestEvents() error {
prevEvents := u.event.PrevEvents()
oldLatest := u.updater.LatestEvents()
u.lastEventIDSent = u.updater.LastEventIDSent()
u.oldStateNID = u.updater.CurrentStateSnapshotNID()
hasBeenSent, err := u.updater.HasEventBeenSent(u.stateAtEvent.EventNID)
if err != nil {
return err
} else if hasBeenSent {
// Already sent this event so we can stop processing
return nil
}
if err = u.updater.StorePreviousEvents(u.stateAtEvent.EventNID, prevEvents); err != nil {
return err
}
eventReference := u.event.EventReference()
// Check if this event is already referenced by another event in the room.
alreadyReferenced, err := u.updater.IsReferenced(eventReference)
if err != nil {
return err
}
u.latest = calculateLatest(oldLatest, alreadyReferenced, prevEvents, types.StateAtEventAndReference{
EventReference: eventReference,
StateAtEvent: u.stateAtEvent,
})
if err = u.latestState(); err != nil {
return err
}
updates, err := updateMemberships(u.ctx, u.db, u.updater, u.removed, u.added)
if err != nil {
return err
}
update, err := u.makeOutputNewRoomEvent()
if err != nil {
return err
}
updates = append(updates, *update)
// Send the event to the output logs.
// We do this inside the database transaction to ensure that we only mark an event as sent if we sent it.
// (n.b. this means that it's possible that the same event will be sent twice if the transaction fails but
// the write to the output log succeeds)
// TODO: This assumes that writing the event to the output log is synchronous. It should be possible to
// send the event asynchronously but we would need to ensure that 1) the events are written to the log in
// the correct order, 2) that pending writes are resent across restarts. In order to avoid writing all the
// necessary bookkeeping we'll keep the event sending synchronous for now.
if err = u.ow.WriteOutputEvents(u.event.RoomID(), updates); err != nil {
return err
}
if err = u.updater.SetLatestEvents(u.roomNID, u.latest, u.stateAtEvent.EventNID, u.newStateNID); err != nil {
return err
}
return u.updater.MarkEventAsSent(u.stateAtEvent.EventNID)
}
func (u *latestEventsUpdater) latestState() error {
var err error
roomState := state.NewStateResolution(u.db)
latestStateAtEvents := make([]types.StateAtEvent, len(u.latest))
for i := range u.latest {
latestStateAtEvents[i] = u.latest[i].StateAtEvent
}
u.newStateNID, err = roomState.CalculateAndStoreStateAfterEvents(
u.ctx, u.roomNID, latestStateAtEvents,
)
if err != nil {
return err
}
u.removed, u.added, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.oldStateNID, u.newStateNID,
)
if err != nil {
return err
}
u.stateBeforeEventRemoves, u.stateBeforeEventAdds, err = roomState.DifferenceBetweeenStateSnapshots(
u.ctx, u.newStateNID, u.stateAtEvent.BeforeStateSnapshotNID,
)
return err
}
func calculateLatest(
oldLatest []types.StateAtEventAndReference,
alreadyReferenced bool,
prevEvents []gomatrixserverlib.EventReference,
newEvent types.StateAtEventAndReference,
) []types.StateAtEventAndReference {
var alreadyInLatest bool
var newLatest []types.StateAtEventAndReference
for _, l := range oldLatest {
keep := true
for _, prevEvent := range prevEvents {
if l.EventID == prevEvent.EventID && bytes.Equal(l.EventSHA256, prevEvent.EventSHA256) {
// This event can be removed from the latest events cause we've found an event that references it.
// (If an event is referenced by another event then it can't be one of the latest events in the room
// because we have an event that comes after it)
keep = false
break
}
}
if l.EventNID == newEvent.EventNID {
alreadyInLatest = true
}
if keep {
// Keep the event in the latest events.
newLatest = append(newLatest, l)
}
}
if !alreadyReferenced && !alreadyInLatest {
// This event is not referenced by any of the events in the room
// and the event is not already in the latest events.
// Add it to the latest events
newLatest = append(newLatest, newEvent)
}
return newLatest
}
func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error) {
latestEventIDs := make([]string, len(u.latest))
for i := range u.latest {
latestEventIDs[i] = u.latest[i].EventID
}
roomVersion, err := u.db.GetRoomVersionForRoom(u.ctx, u.event.RoomID())
if err != nil {
return nil, err
}
ore := api.OutputNewRoomEvent{
Event: u.event.Headered(roomVersion),
LastSentEventID: u.lastEventIDSent,
LatestEventIDs: latestEventIDs,
TransactionID: u.transactionID,
}
var stateEventNIDs []types.EventNID
for _, entry := range u.added {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
for _, entry := range u.removed {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
for _, entry := range u.stateBeforeEventRemoves {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
for _, entry := range u.stateBeforeEventAdds {
stateEventNIDs = append(stateEventNIDs, entry.EventNID)
}
stateEventNIDs = stateEventNIDs[:util.SortAndUnique(eventNIDSorter(stateEventNIDs))]
eventIDMap, err := u.db.EventIDs(u.ctx, stateEventNIDs)
if err != nil {
return nil, err
}
for _, entry := range u.added {
ore.AddsStateEventIDs = append(ore.AddsStateEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.removed {
ore.RemovesStateEventIDs = append(ore.RemovesStateEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.stateBeforeEventRemoves {
ore.StateBeforeRemovesEventIDs = append(ore.StateBeforeRemovesEventIDs, eventIDMap[entry.EventNID])
}
for _, entry := range u.stateBeforeEventAdds {
ore.StateBeforeAddsEventIDs = append(ore.StateBeforeAddsEventIDs, eventIDMap[entry.EventNID])
}
ore.SendAsServer = u.sendAsServer
return &api.OutputEvent{
Type: api.OutputTypeNewRoomEvent,
NewRoomEvent: &ore,
}, nil
}
type eventNIDSorter []types.EventNID
func (s eventNIDSorter) Len() int { return len(s) }
func (s eventNIDSorter) Less(i, j int) bool { return s[i] < s[j] }
func (s eventNIDSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -0,0 +1,306 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// updateMembership updates the current membership and the invites for each
// user affected by a change in the current state of the room.
// Returns a list of output events to write to the kafka log to inform the
// consumers about the invites added or retired by the change in current state.
func updateMemberships(
ctx context.Context,
db storage.Database,
updater types.RoomRecentEventsUpdater,
removed, added []types.StateEntry,
) ([]api.OutputEvent, error) {
changes := membershipChanges(removed, added)
var eventNIDs []types.EventNID
for _, change := range changes {
if change.addedEventNID != 0 {
eventNIDs = append(eventNIDs, change.addedEventNID)
}
if change.removedEventNID != 0 {
eventNIDs = append(eventNIDs, change.removedEventNID)
}
}
// Load the event JSON so we can look up the "membership" key.
// TODO: Maybe add a membership key to the events table so we can load that
// key without having to load the entire event JSON?
events, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
var updates []api.OutputEvent
for _, change := range changes {
var ae *gomatrixserverlib.Event
var re *gomatrixserverlib.Event
targetUserNID := change.EventStateKeyNID
if change.removedEventNID != 0 {
ev, _ := eventMap(events).lookup(change.removedEventNID)
if ev != nil {
re = &ev.Event
}
}
if change.addedEventNID != 0 {
ev, _ := eventMap(events).lookup(change.addedEventNID)
if ev != nil {
ae = &ev.Event
}
}
if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
}
}
return updates, nil
}
func updateMembership(
updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
var err error
// Default the membership to Leave if no event was added or removed.
oldMembership := gomatrixserverlib.Leave
newMembership := gomatrixserverlib.Leave
if remove != nil {
oldMembership, err = remove.Membership()
if err != nil {
return nil, err
}
}
if add != nil {
newMembership, err = add.Membership()
if err != nil {
return nil, err
}
}
if oldMembership == newMembership && newMembership != gomatrixserverlib.Join {
// If the membership is the same then nothing changed and we can return
// immediately, unless it's a Join update (e.g. profile update).
return updates, nil
}
mu, err := updater.MembershipUpdater(targetUserNID)
if err != nil {
return nil, err
}
switch newMembership {
case gomatrixserverlib.Invite:
return updateToInviteMembership(mu, add, updates, updater.RoomVersion())
case gomatrixserverlib.Join:
return updateToJoinMembership(mu, add, updates)
case gomatrixserverlib.Leave, gomatrixserverlib.Ban:
return updateToLeaveMembership(mu, add, newMembership, updates)
default:
panic(fmt.Errorf(
"input: membership %q is not one of the allowed values", newMembership,
))
}
}
func updateToInviteMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
roomVersion gomatrixserverlib.RoomVersion,
) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a
// remote server via the federation invite API. In those cases we don't need
// to send the event.
needsSending, err := mu.SetToInvite(*add)
if err != nil {
return nil, err
}
if needsSending {
// We notify the consumers using a special event even though we will
// notify them about the change in current state as part of the normal
// room event stream. This ensures that the consumers only have to
// consider a single stream of events when determining whether a user
// is invited, rather than having to combine multiple streams themselves.
onie := api.OutputNewInviteEvent{
Event: add.Headered(roomVersion),
RoomVersion: roomVersion,
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie,
})
}
return updates, nil
}
func updateToJoinMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// If the user is already marked as being joined, we call SetToJoin to update
// the event ID then we can return immediately. Retired is ignored as there
// is no invite event to retire.
if mu.IsJoin() {
_, err := mu.SetToJoin(add.Sender(), add.EventID(), true)
if err != nil {
return nil, err
}
return updates, nil
}
// When we mark a user as being joined we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
retired, err := mu.SetToJoin(add.Sender(), add.EventID(), false)
if err != nil {
return nil, err
}
for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: gomatrixserverlib.Join,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
return updates, nil
}
func updateToLeaveMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event,
newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// If the user is already neither joined, nor invited to the room then we
// can return immediately.
if mu.IsLeave() {
return updates, nil
}
// When we mark a user as having left we will invalidate any invites that
// are active for that user. We notify the consumers that the invites have
// been retired using a special event, even though they could infer this
// by studying the state changes in the room event stream.
retired, err := mu.SetToLeave(add.Sender(), add.EventID())
if err != nil {
return nil, err
}
for _, eventID := range retired {
orie := api.OutputRetireInviteEvent{
EventID: eventID,
Membership: newMembership,
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
return updates, nil
}
// membershipChanges pairs up the membership state changes from a sorted list
// of state removed and a sorted list of state added.
func membershipChanges(removed, added []types.StateEntry) []stateChange {
changes := pairUpChanges(removed, added)
var result []stateChange
for _, c := range changes {
if c.EventTypeNID == types.MRoomMemberNID {
result = append(result, c)
}
}
return result
}
type stateChange struct {
types.StateKeyTuple
removedEventNID types.EventNID
addedEventNID types.EventNID
}
// pairUpChanges pairs up the state events added and removed for each type,
// state key tuple. Assumes that removed and added are sorted.
func pairUpChanges(removed, added []types.StateEntry) []stateChange {
var ai int
var ri int
var result []stateChange
for {
switch {
case ai == len(added):
// We've reached the end of the added entries.
// The rest of the removed list are events that were removed without
// an event with the same state key being added.
for _, s := range removed[ri:] {
result = append(result, stateChange{
StateKeyTuple: s.StateKeyTuple,
removedEventNID: s.EventNID,
})
}
return result
case ri == len(removed):
// We've reached the end of the removed entries.
// The rest of the added list are events that were added without
// an event with the same state key being removed.
for _, s := range added[ai:] {
result = append(result, stateChange{
StateKeyTuple: s.StateKeyTuple,
addedEventNID: s.EventNID,
})
}
return result
case added[ai].StateKeyTuple == removed[ri].StateKeyTuple:
// The tuple is in both lists so an event with that key is being
// removed and another event with the same key is being added.
result = append(result, stateChange{
StateKeyTuple: added[ai].StateKeyTuple,
removedEventNID: removed[ri].EventNID,
addedEventNID: added[ai].EventNID,
})
ai++
ri++
case added[ai].StateKeyTuple.LessThan(removed[ri].StateKeyTuple):
// The lists are sorted so the added entry being less than the
// removed entry means that the added event was added without an
// event with the same key being removed.
result = append(result, stateChange{
StateKeyTuple: added[ai].StateKeyTuple,
addedEventNID: added[ai].EventNID,
})
ai++
default:
// Reaching the default case implies that the removed entry is less
// than the added entry. Since the lists are sorted this means that
// the removed event was removed without an event with the same
// key being added.
result = append(result, stateChange{
StateKeyTuple: removed[ai].StateKeyTuple,
removedEventNID: removed[ri].EventNID,
})
ri++
}
}
}

View file

@ -0,0 +1,968 @@
// Copyright 2017 Vector Creations Ltd
// Copyright 2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/state"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// QueryLatestEventsAndState implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryLatestEventsAndState(
ctx context.Context,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
if err != nil {
response.RoomExists = false
return nil
}
roomState := state.NewStateResolution(r.DB)
response.QueryLatestEventsAndStateRequest = *request
roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID)
if err != nil {
return err
}
if roomNID == 0 {
return nil
}
response.RoomExists = true
response.RoomVersion = roomVersion
var currentStateSnapshotNID types.StateSnapshotNID
response.LatestEvents, currentStateSnapshotNID, response.Depth, err =
r.DB.LatestEventIDs(ctx, roomNID)
if err != nil {
return err
}
var stateEntries []types.StateEntry
if len(request.StateToFetch) == 0 {
// Look up all room state.
stateEntries, err = roomState.LoadStateAtSnapshot(
ctx, currentStateSnapshotNID,
)
} else {
// Look up the current state for the requested tuples.
stateEntries, err = roomState.LoadStateAtSnapshotForStringTuples(
ctx, currentStateSnapshotNID, request.StateToFetch,
)
}
if err != nil {
return err
}
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}
for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
}
return nil
}
// QueryStateAfterEvents implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryStateAfterEvents(
ctx context.Context,
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
if err != nil {
response.RoomExists = false
return nil
}
roomState := state.NewStateResolution(r.DB)
response.QueryStateAfterEventsRequest = *request
roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID)
if err != nil {
return err
}
if roomNID == 0 {
return nil
}
response.RoomExists = true
response.RoomVersion = roomVersion
prevStates, err := r.DB.StateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
switch err.(type) {
case types.MissingEventError:
return nil
default:
return err
}
}
response.PrevEventsExist = true
// Look up the currrent state for the requested tuples.
stateEntries, err := roomState.LoadStateAfterEventsForStringTuples(
ctx, roomNID, prevStates, request.StateToFetch,
)
if err != nil {
return err
}
stateEvents, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return err
}
for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
}
return nil
}
// QueryEventsByID implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
response.QueryEventsByIDRequest = *request
eventNIDMap, err := r.DB.EventNIDs(ctx, request.EventIDs)
if err != nil {
return err
}
var eventNIDs []types.EventNID
for _, nid := range eventNIDMap {
eventNIDs = append(eventNIDs, nid)
}
events, err := r.loadEvents(ctx, eventNIDs)
if err != nil {
return err
}
for _, event := range events {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
return nil
}
func (r *RoomserverInternalAPI) loadStateEvents(
ctx context.Context, stateEntries []types.StateEntry,
) ([]gomatrixserverlib.Event, error) {
eventNIDs := make([]types.EventNID, len(stateEntries))
for i := range stateEntries {
eventNIDs[i] = stateEntries[i].EventNID
}
return r.loadEvents(ctx, eventNIDs)
}
func (r *RoomserverInternalAPI) loadEvents(
ctx context.Context, eventNIDs []types.EventNID,
) ([]gomatrixserverlib.Event, error) {
stateEvents, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
result := make([]gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents {
result[i] = stateEvents[i].Event
}
return result, nil
}
// QueryMembershipForUser implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMembershipForUser(
ctx context.Context,
request *api.QueryMembershipForUserRequest,
response *api.QueryMembershipForUserResponse,
) error {
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.UserID)
if err != nil {
return err
}
if membershipEventNID == 0 {
response.HasBeenInRoom = false
return nil
}
response.IsInRoom = stillInRoom
eventIDMap, err := r.DB.EventIDs(ctx, []types.EventNID{membershipEventNID})
if err != nil {
return err
}
response.EventID = eventIDMap[membershipEventNID]
return nil
}
// QueryMembershipsForRoom implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMembershipsForRoom(
ctx context.Context,
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
membershipEventNID, stillInRoom, err := r.DB.GetMembership(ctx, roomNID, request.Sender)
if err != nil {
return err
}
if membershipEventNID == 0 {
response.HasBeenInRoom = false
response.JoinEvents = nil
return nil
}
response.HasBeenInRoom = true
response.JoinEvents = []gomatrixserverlib.ClientEvent{}
var events []types.Event
var stateEntries []types.StateEntry
if stillInRoom {
var eventNIDs []types.EventNID
eventNIDs, err = r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, request.JoinedOnly)
if err != nil {
return err
}
events, err = r.DB.Events(ctx, eventNIDs)
} else {
stateEntries, err = stateBeforeEvent(ctx, r.DB, membershipEventNID)
if err != nil {
logrus.WithField("membership_event_nid", membershipEventNID).WithError(err).Error("failed to load state before event")
return err
}
events, err = getMembershipsAtState(ctx, r.DB, stateEntries, request.JoinedOnly)
}
if err != nil {
return err
}
for _, event := range events {
clientEvent := gomatrixserverlib.ToClientEvent(event.Event, gomatrixserverlib.FormatAll)
response.JoinEvents = append(response.JoinEvents, clientEvent)
}
return nil
}
func stateBeforeEvent(ctx context.Context, db storage.Database, eventNID types.EventNID) ([]types.StateEntry, error) {
roomState := state.NewStateResolution(db)
// Lookup the event NID
eIDs, err := db.EventIDs(ctx, []types.EventNID{eventNID})
if err != nil {
return nil, err
}
eventIDs := []string{eIDs[eventNID]}
prevState, err := db.StateAtEventIDs(ctx, eventIDs)
if err != nil {
return nil, err
}
// Fetch the state as it was when this event was fired
return roomState.LoadCombinedStateAfterEvents(ctx, prevState)
}
// getMembershipsAtState filters the state events to
// only keep the "m.room.member" events with a "join" membership. These events are returned.
// Returns an error if there was an issue fetching the events.
func getMembershipsAtState(
ctx context.Context, db storage.Database, stateEntries []types.StateEntry, joinedOnly bool,
) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomMemberNID {
eventNIDs = append(eventNIDs, entry.EventNID)
}
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
if !joinedOnly {
return stateEvents, nil
}
// Filter the events to only keep the "join" membership events
var events []types.Event
for _, event := range stateEvents {
membership, err := event.Membership()
if err != nil {
return nil, err
}
if membership == gomatrixserverlib.Join {
events = append(events, event)
}
}
return events, nil
}
// QueryInvitesForUser implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryInvitesForUser(
ctx context.Context,
request *api.QueryInvitesForUserRequest,
response *api.QueryInvitesForUserResponse,
) error {
roomNID, err := r.DB.RoomNID(ctx, request.RoomID)
if err != nil {
return err
}
targetUserNIDs, err := r.DB.EventStateKeyNIDs(ctx, []string{request.TargetUserID})
if err != nil {
return err
}
targetUserNID := targetUserNIDs[request.TargetUserID]
senderUserNIDs, err := r.DB.GetInvitesForUser(ctx, roomNID, targetUserNID)
if err != nil {
return err
}
senderUserIDs, err := r.DB.EventStateKeys(ctx, senderUserNIDs)
if err != nil {
return err
}
for _, senderUserID := range senderUserIDs {
response.InviteSenderUserIDs = append(response.InviteSenderUserIDs, senderUserID)
}
return nil
}
// QueryServerAllowedToSeeEvent implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
events, err := r.DB.EventsFromIDs(ctx, []string{request.EventID})
if err != nil {
return
}
if len(events) == 0 {
response.AllowedToSeeEvent = false // event doesn't exist so not allowed to see
return
}
isServerInRoom, err := r.isServerCurrentlyInRoom(ctx, request.ServerName, events[0].RoomID())
if err != nil {
return
}
response.AllowedToSeeEvent, err = r.checkServerAllowedToSeeEvent(
ctx, request.EventID, request.ServerName, isServerInRoom,
)
return
}
func (r *RoomserverInternalAPI) checkServerAllowedToSeeEvent(
ctx context.Context, eventID string, serverName gomatrixserverlib.ServerName, isServerInRoom bool,
) (bool, error) {
roomState := state.NewStateResolution(r.DB)
stateEntries, err := roomState.LoadStateAtEvent(ctx, eventID)
if err != nil {
return false, err
}
// TODO: We probably want to make it so that we don't have to pull
// out all the state if possible.
stateAtEvent, err := r.loadStateEvents(ctx, stateEntries)
if err != nil {
return false, err
}
return auth.IsServerAllowed(serverName, isServerInRoom, stateAtEvent), nil
}
// QueryMissingEvents implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryMissingEvents(
ctx context.Context,
request *api.QueryMissingEventsRequest,
response *api.QueryMissingEventsResponse,
) error {
var front []string
eventsToFilter := make(map[string]bool, len(request.LatestEvents))
visited := make(map[string]bool, request.Limit) // request.Limit acts as a hint to size.
for _, id := range request.EarliestEvents {
visited[id] = true
}
for _, id := range request.LatestEvents {
if !visited[id] {
front = append(front, id)
eventsToFilter[id] = true
}
}
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
loadedEvents, err := r.loadEvents(ctx, resultNIDs)
if err != nil {
return err
}
response.Events = make([]gomatrixserverlib.HeaderedEvent, 0, len(loadedEvents)-len(eventsToFilter))
for _, event := range loadedEvents {
if !eventsToFilter[event.EventID()] {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
}
return err
}
// QueryBackfill implements api.RoomServerQueryAPI
func (r *RoomserverInternalAPI) QueryBackfill(
ctx context.Context,
request *api.QueryBackfillRequest,
response *api.QueryBackfillResponse,
) error {
// if we are requesting the backfill then we need to do a federation hit
// TODO: we could be more sensible and fetch as many events we already have then request the rest
// which is what the syncapi does already.
if request.ServerName == r.ServerName {
return r.backfillViaFederation(ctx, request, response)
}
// someone else is requesting the backfill, try to service their request.
var err error
var front []string
// The limit defines the maximum number of events to retrieve, so it also
// defines the highest number of elements in the map below.
visited := make(map[string]bool, request.Limit)
// The provided event IDs have already been seen by the request's emitter,
// and will be retrieved anyway, so there's no need to care about them if
// they appear in our exploration of the event tree.
for _, id := range request.EarliestEventsIDs {
visited[id] = true
}
front = request.EarliestEventsIDs
// Scan the event tree for events to send back.
resultNIDs, err := r.scanEventTree(ctx, front, visited, request.Limit, request.ServerName)
if err != nil {
return err
}
// Retrieve events from the list that was filled previously.
var loadedEvents []gomatrixserverlib.Event
loadedEvents, err = r.loadEvents(ctx, resultNIDs)
if err != nil {
return err
}
for _, event := range loadedEvents {
roomVersion, verr := r.DB.GetRoomVersionForRoom(ctx, event.RoomID())
if verr != nil {
return verr
}
response.Events = append(response.Events, event.Headered(roomVersion))
}
return err
}
func (r *RoomserverInternalAPI) backfillViaFederation(ctx context.Context, req *api.QueryBackfillRequest, res *api.QueryBackfillResponse) error {
roomVer, err := r.DB.GetRoomVersionForRoom(ctx, req.RoomID)
if err != nil {
return fmt.Errorf("backfillViaFederation: unknown room version for room %s : %w", req.RoomID, err)
}
requester := newBackfillRequester(r.DB, r.FedClient, r.ServerName)
events, err := gomatrixserverlib.RequestBackfill(
ctx, requester,
r.KeyRing, req.RoomID, roomVer, req.EarliestEventsIDs, req.Limit)
if err != nil {
return err
}
logrus.WithField("room_id", req.RoomID).Infof("backfilled %d events", len(events))
// persist these new events - auth checks have already been done
roomNID, backfilledEventMap := persistEvents(ctx, r.DB, events)
if err != nil {
return err
}
for _, ev := range backfilledEventMap {
// now add state for these events
stateIDs, ok := requester.eventIDToBeforeStateIDs[ev.EventID()]
if !ok {
// this should be impossible as all events returned must have pass Step 5 of the PDU checks
// which requires a list of state IDs.
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to find state IDs for event which passed auth checks")
continue
}
var entries []types.StateEntry
if entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs); err != nil {
// attempt to fetch the missing events
r.fetchAndStoreMissingEvents(ctx, roomVer, requester, stateIDs)
// try again
entries, err = r.DB.StateEntriesForEventIDs(ctx, stateIDs)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to get state entries for event")
return err
}
}
var beforeStateSnapshotNID types.StateSnapshotNID
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
return err
}
util.GetLogger(ctx).Infof("Backfilled event %s (nid=%d) getting snapshot %v with entries %+v", ev.EventID(), ev.EventNID, beforeStateSnapshotNID, entries)
if err = r.DB.SetState(ctx, ev.EventNID, beforeStateSnapshotNID); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist snapshot nid")
}
}
// TODO: update backwards extremities, as that should be moved from syncapi to roomserver at some point.
res.Events = events
return nil
}
func (r *RoomserverInternalAPI) isServerCurrentlyInRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) {
roomNID, err := r.DB.RoomNID(ctx, roomID)
if err != nil {
return false, err
}
eventNIDs, err := r.DB.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
if err != nil {
return false, err
}
events, err := r.DB.Events(ctx, eventNIDs)
if err != nil {
return false, err
}
gmslEvents := make([]gomatrixserverlib.Event, len(events))
for i := range events {
gmslEvents[i] = events[i].Event
}
return auth.IsAnyUserOnServerWithMembership(serverName, gmslEvents, gomatrixserverlib.Join), nil
}
// fetchAndStoreMissingEvents does a best-effort fetch and store of missing events specified in stateIDs. Returns no error as it is just
// best effort.
func (r *RoomserverInternalAPI) fetchAndStoreMissingEvents(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
backfillRequester *backfillRequester, stateIDs []string) {
servers := backfillRequester.servers
// work out which are missing
nidMap, err := r.DB.EventNIDs(ctx, stateIDs)
if err != nil {
util.GetLogger(ctx).WithError(err).Warn("cannot query missing events")
return
}
missingMap := make(map[string]*gomatrixserverlib.HeaderedEvent) // id -> event
for _, id := range stateIDs {
if _, ok := nidMap[id]; !ok {
missingMap[id] = nil
}
}
util.GetLogger(ctx).Infof("Fetching %d missing state events (from %d possible servers)", len(missingMap), len(servers))
// fetch the events from federation. Loop the servers first so if we find one that works we stick with them
for _, srv := range servers {
for id, ev := range missingMap {
if ev != nil {
continue // already found
}
logger := util.GetLogger(ctx).WithField("server", srv).WithField("event_id", id)
res, err := r.FedClient.GetEvent(ctx, srv, id)
if err != nil {
logger.WithError(err).Warn("failed to get event from server")
continue
}
loader := gomatrixserverlib.NewEventsLoader(roomVer, r.KeyRing, backfillRequester, backfillRequester.ProvideEvents, false)
result, err := loader.LoadAndVerify(ctx, res.PDUs, gomatrixserverlib.TopologicalOrderByPrevEvents)
if err != nil {
logger.WithError(err).Warn("failed to load and verify event")
continue
}
logger.Infof("returned %d PDUs which made events %+v", len(res.PDUs), result)
for _, res := range result {
if res.Error != nil {
logger.WithError(err).Warn("event failed PDU checks")
continue
}
missingMap[id] = res.Event
}
}
}
var newEvents []gomatrixserverlib.HeaderedEvent
for _, ev := range missingMap {
if ev != nil {
newEvents = append(newEvents, *ev)
}
}
util.GetLogger(ctx).Infof("Persisting %d new events", len(newEvents))
persistEvents(ctx, r.DB, newEvents)
}
// TODO: Remove this when we have tests to assert correctness of this function
// nolint:gocyclo
func (r *RoomserverInternalAPI) scanEventTree(
ctx context.Context, front []string, visited map[string]bool, limit int,
serverName gomatrixserverlib.ServerName,
) ([]types.EventNID, error) {
var resultNIDs []types.EventNID
var err error
var allowed bool
var events []types.Event
var next []string
var pre string
// TODO: add tests for this function to ensure it meets the contract that callers expect (and doc what that is supposed to be)
// Currently, callers like QueryBackfill will call scanEventTree with a pre-populated `visited` map, assuming that by doing
// so means that the events in that map will NOT be returned from this function. That is not currently true, resulting in
// duplicate events being sent in response to /backfill requests.
initialIgnoreList := make(map[string]bool, len(visited))
for k, v := range visited {
initialIgnoreList[k] = v
}
resultNIDs = make([]types.EventNID, 0, limit)
var checkedServerInRoom bool
var isServerInRoom bool
// Loop through the event IDs to retrieve the requested events and go
// through the whole tree (up to the provided limit) using the events'
// "prev_event" key.
BFSLoop:
for len(front) > 0 {
// Prevent unnecessary allocations: reset the slice only when not empty.
if len(next) > 0 {
next = make([]string, 0)
}
// Retrieve the events to process from the database.
events, err = r.DB.EventsFromIDs(ctx, front)
if err != nil {
return resultNIDs, err
}
if !checkedServerInRoom && len(events) > 0 {
// It's nasty that we have to extract the room ID from an event, but many federation requests
// only talk in event IDs, no room IDs at all (!!!)
ev := events[0]
isServerInRoom, err = r.isServerCurrentlyInRoom(ctx, serverName, ev.RoomID())
if err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to check if server is currently in room, assuming not.")
}
checkedServerInRoom = true
}
for _, ev := range events {
// Break out of the loop if the provided limit is reached.
if len(resultNIDs) == limit {
break BFSLoop
}
if !initialIgnoreList[ev.EventID()] {
// Update the list of events to retrieve.
resultNIDs = append(resultNIDs, ev.EventNID)
}
// Loop through the event's parents.
for _, pre = range ev.PrevEventIDs() {
// Only add an event to the list of next events to process if it
// hasn't been seen before.
if !visited[pre] {
visited[pre] = true
allowed, err = r.checkServerAllowedToSeeEvent(ctx, pre, serverName, isServerInRoom)
if err != nil {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).WithError(err).Error(
"Error checking if allowed to see event",
)
return resultNIDs, err
}
// If the event hasn't been seen before and the HS
// requesting to retrieve it is allowed to do so, add it to
// the list of events to retrieve.
if allowed {
next = append(next, pre)
} else {
util.GetLogger(ctx).WithField("server", serverName).WithField("event_id", pre).Info("Not allowed to see event")
}
}
}
}
// Repeat the same process with the parent events we just processed.
front = next
}
return resultNIDs, err
}
// QueryStateAndAuthChain implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryStateAndAuthChain(
ctx context.Context,
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
response.QueryStateAndAuthChainRequest = *request
roomNID, err := r.DB.RoomNIDExcludingStubs(ctx, request.RoomID)
if err != nil {
return err
}
if roomNID == 0 {
return nil
}
response.RoomExists = true
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
if err != nil {
return err
}
response.RoomVersion = roomVersion
stateEvents, err := r.loadStateAtEventIDs(ctx, request.PrevEventIDs)
if err != nil {
return err
}
response.PrevEventsExist = true
// add the auth event IDs for the current state events too
var authEventIDs []string
authEventIDs = append(authEventIDs, request.AuthEventIDs...)
for _, se := range stateEvents {
authEventIDs = append(authEventIDs, se.AuthEventIDs()...)
}
authEventIDs = util.UniqueStrings(authEventIDs) // de-dupe
authEvents, err := getAuthChain(ctx, r.DB.EventsFromIDs, authEventIDs)
if err != nil {
return err
}
if request.ResolveState {
if stateEvents, err = state.ResolveConflictsAdhoc(
roomVersion, stateEvents, authEvents,
); err != nil {
return err
}
}
for _, event := range stateEvents {
response.StateEvents = append(response.StateEvents, event.Headered(roomVersion))
}
for _, event := range authEvents {
response.AuthChainEvents = append(response.AuthChainEvents, event.Headered(roomVersion))
}
return err
}
func (r *RoomserverInternalAPI) loadStateAtEventIDs(ctx context.Context, eventIDs []string) ([]gomatrixserverlib.Event, error) {
roomState := state.NewStateResolution(r.DB)
prevStates, err := r.DB.StateAtEventIDs(ctx, eventIDs)
if err != nil {
switch err.(type) {
case types.MissingEventError:
return nil, nil
default:
return nil, err
}
}
// Look up the currrent state for the requested tuples.
stateEntries, err := roomState.LoadCombinedStateAfterEvents(
ctx, prevStates,
)
if err != nil {
return nil, err
}
return r.loadStateEvents(ctx, stateEntries)
}
type eventsFromIDs func(context.Context, []string) ([]types.Event, error)
// getAuthChain fetches the auth chain for the given auth events. An auth chain
// is the list of all events that are referenced in the auth_events section, and
// all their auth_events, recursively. The returned set of events contain the
// given events. Will *not* error if we don't have all auth events.
func getAuthChain(
ctx context.Context, fn eventsFromIDs, authEventIDs []string,
) ([]gomatrixserverlib.Event, error) {
// List of event IDs to fetch. On each pass, these events will be requested
// from the database and the `eventsToFetch` will be updated with any new
// events that we have learned about and need to find. When `eventsToFetch`
// is eventually empty, we should have reached the end of the chain.
eventsToFetch := authEventIDs
authEventsMap := make(map[string]gomatrixserverlib.Event)
for len(eventsToFetch) > 0 {
// Try to retrieve the events from the database.
events, err := fn(ctx, eventsToFetch)
if err != nil {
return nil, err
}
// We've now fetched these events so clear out `eventsToFetch`. Soon we may
// add newly discovered events to this for the next pass.
eventsToFetch = eventsToFetch[:0]
for _, event := range events {
// Store the event in the event map - this prevents us from requesting it
// from the database again.
authEventsMap[event.EventID()] = event.Event
// Extract all of the auth events from the newly obtained event. If we
// don't already have a record of the event, record it in the list of
// events we want to request for the next pass.
for _, authEvent := range event.AuthEvents() {
if _, ok := authEventsMap[authEvent.EventID]; !ok {
eventsToFetch = append(eventsToFetch, authEvent.EventID)
}
}
}
}
// We've now retrieved all of the events we can. Flatten them down into an
// array and return them.
var authEvents []gomatrixserverlib.Event
for _, event := range authEventsMap {
authEvents = append(authEvents, event)
}
return authEvents, nil
}
func persistEvents(ctx context.Context, db storage.Database, events []gomatrixserverlib.HeaderedEvent) (types.RoomNID, map[string]types.Event) {
var roomNID types.RoomNID
backfilledEventMap := make(map[string]types.Event)
for _, ev := range events {
nidMap, err := db.EventNIDs(ctx, ev.AuthEventIDs())
if err != nil { // this shouldn't happen as RequestBackfill already found them
logrus.WithError(err).WithField("auth_events", ev.AuthEventIDs()).Error("Failed to find one or more auth events")
continue
}
authNids := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
authNids[i] = nid
i++
}
var stateAtEvent types.StateAtEvent
roomNID, stateAtEvent, err = db.StoreEvent(ctx, ev.Unwrap(), nil, authNids)
if err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("Failed to persist event")
continue
}
backfilledEventMap[ev.EventID()] = types.Event{
EventNID: stateAtEvent.StateEntry.EventNID,
Event: ev.Unwrap(),
}
}
return roomNID, backfilledEventMap
}
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryRoomVersionCapabilities(
ctx context.Context,
request *api.QueryRoomVersionCapabilitiesRequest,
response *api.QueryRoomVersionCapabilitiesResponse,
) error {
response.DefaultRoomVersion = version.DefaultRoomVersion()
response.AvailableRoomVersions = make(map[gomatrixserverlib.RoomVersion]string)
for v, desc := range version.SupportedRoomVersions() {
if desc.Stable {
response.AvailableRoomVersions[v] = "stable"
} else {
response.AvailableRoomVersions[v] = "unstable"
}
}
return nil
}
// QueryRoomVersionCapabilities implements api.RoomserverInternalAPI
func (r *RoomserverInternalAPI) QueryRoomVersionForRoom(
ctx context.Context,
request *api.QueryRoomVersionForRoomRequest,
response *api.QueryRoomVersionForRoomResponse,
) error {
if roomVersion, ok := r.ImmutableCache.GetRoomVersion(request.RoomID); ok {
response.RoomVersion = roomVersion
return nil
}
roomVersion, err := r.DB.GetRoomVersionForRoom(ctx, request.RoomID)
if err != nil {
return err
}
response.RoomVersion = roomVersion
r.ImmutableCache.StoreRoomVersion(request.RoomID, response.RoomVersion)
return nil
}

View file

@ -0,0 +1,278 @@
package internal
import (
"context"
"github.com/matrix-org/dendrite/roomserver/auth"
"github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// backfillRequester implements gomatrixserverlib.BackfillRequester
type backfillRequester struct {
db storage.Database
fedClient *gomatrixserverlib.FederationClient
thisServer gomatrixserverlib.ServerName
// per-request state
servers []gomatrixserverlib.ServerName
eventIDToBeforeStateIDs map[string][]string
eventIDMap map[string]gomatrixserverlib.Event
}
func newBackfillRequester(db storage.Database, fedClient *gomatrixserverlib.FederationClient, thisServer gomatrixserverlib.ServerName) *backfillRequester {
return &backfillRequester{
db: db,
fedClient: fedClient,
thisServer: thisServer,
eventIDToBeforeStateIDs: make(map[string][]string),
eventIDMap: make(map[string]gomatrixserverlib.Event),
}
}
func (b *backfillRequester) StateIDsBeforeEvent(ctx context.Context, targetEvent gomatrixserverlib.HeaderedEvent) ([]string, error) {
b.eventIDMap[targetEvent.EventID()] = targetEvent.Unwrap()
if ids, ok := b.eventIDToBeforeStateIDs[targetEvent.EventID()]; ok {
return ids, nil
}
// if we have exactly 1 prev event and we know the state of the room at that prev event, then just roll forward the prev event.
// Else, we have to hit /state_ids because either we don't know the state at all at this event (new backwards extremity) or
// we don't know the result of state res to merge forks (2 or more prev_events)
if len(targetEvent.PrevEventIDs()) == 1 {
prevEventID := targetEvent.PrevEventIDs()[0]
prevEvent, ok := b.eventIDMap[prevEventID]
if !ok {
goto FederationHit
}
prevEventStateIDs, ok := b.eventIDToBeforeStateIDs[prevEventID]
if !ok {
goto FederationHit
}
newStateIDs := b.calculateNewStateIDs(targetEvent.Unwrap(), prevEvent, prevEventStateIDs)
if newStateIDs != nil {
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs, nil
}
// else we failed to calculate the new state, so fallthrough
}
FederationHit:
var lastErr error
logrus.WithField("event_id", targetEvent.EventID()).Info("Requesting /state_ids at event")
for _, srv := range b.servers { // hit any valid server
c := gomatrixserverlib.FederatedStateProvider{
FedClient: b.fedClient,
RememberAuthEvents: false,
Server: srv,
}
res, err := c.StateIDsBeforeEvent(ctx, targetEvent)
if err != nil {
lastErr = err
continue
}
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = res
return res, nil
}
return nil, lastErr
}
func (b *backfillRequester) calculateNewStateIDs(targetEvent, prevEvent gomatrixserverlib.Event, prevEventStateIDs []string) []string {
newStateIDs := prevEventStateIDs[:]
if prevEvent.StateKey() == nil {
// state is the same as the previous event
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs
}
missingState := false // true if we are missing the info for a state event ID
foundEvent := false // true if we found a (type, state_key) match
// find which state ID to replace, if any
for i, id := range newStateIDs {
ev, ok := b.eventIDMap[id]
if !ok {
missingState = true
continue
}
// The state IDs BEFORE the target event are the state IDs BEFORE the prev_event PLUS the prev_event itself
if ev.Type() == prevEvent.Type() && ev.StateKey() != nil && *ev.StateKey() == *prevEvent.StateKey() {
newStateIDs[i] = prevEvent.EventID()
foundEvent = true
break
}
}
if !foundEvent && !missingState {
// we can be certain that this is new state
newStateIDs = append(newStateIDs, prevEvent.EventID())
foundEvent = true
}
if foundEvent {
b.eventIDToBeforeStateIDs[targetEvent.EventID()] = newStateIDs
return newStateIDs
}
return nil
}
func (b *backfillRequester) StateBeforeEvent(ctx context.Context, roomVer gomatrixserverlib.RoomVersion,
event gomatrixserverlib.HeaderedEvent, eventIDs []string) (map[string]*gomatrixserverlib.Event, error) {
// try to fetch the events from the database first
events, err := b.ProvideEvents(roomVer, eventIDs)
if err != nil {
// non-fatal, fallthrough
logrus.WithError(err).Info("Failed to fetch events")
} else {
logrus.Infof("Fetched %d/%d events from the database", len(events), len(eventIDs))
if len(events) == len(eventIDs) {
result := make(map[string]*gomatrixserverlib.Event)
for i := range events {
result[events[i].EventID()] = &events[i]
b.eventIDMap[events[i].EventID()] = events[i]
}
return result, nil
}
}
c := gomatrixserverlib.FederatedStateProvider{
FedClient: b.fedClient,
RememberAuthEvents: false,
Server: b.servers[0],
}
result, err := c.StateBeforeEvent(ctx, roomVer, event, eventIDs)
if err != nil {
return nil, err
}
for eventID, ev := range result {
b.eventIDMap[eventID] = *ev
}
return result, nil
}
// ServersAtEvent is called when trying to determine which server to request from.
// It returns a list of servers which can be queried for backfill requests. These servers
// will be servers that are in the room already. The entries at the beginning are preferred servers
// and will be tried first. An empty list will fail the request.
func (b *backfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) (servers []gomatrixserverlib.ServerName) {
// getMembershipsBeforeEventNID requires a NID, so retrieving the NID for
// the event is necessary.
NIDs, err := b.db.EventNIDs(ctx, []string{eventID})
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get event NID for event")
return
}
stateEntries, err := stateBeforeEvent(ctx, b.db, NIDs[eventID])
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to load state before event")
return
}
// possibly return all joined servers depending on history visiblity
memberEventsFromVis, err := joinEventsFromHistoryVisibility(ctx, b.db, roomID, stateEntries)
if err != nil {
logrus.WithError(err).Error("ServersAtEvent: failed calculate servers from history visibility rules")
return
}
logrus.Infof("ServersAtEvent including %d current events from history visibility", len(memberEventsFromVis))
// Retrieve all "m.room.member" state events of "join" membership, which
// contains the list of users in the room before the event, therefore all
// the servers in it at that moment.
memberEvents, err := getMembershipsAtState(ctx, b.db, stateEntries, true)
if err != nil {
logrus.WithField("event_id", eventID).WithError(err).Error("ServersAtEvent: failed to get memberships before event")
return
}
memberEvents = append(memberEvents, memberEventsFromVis...)
// Store the server names in a temporary map to avoid duplicates.
serverSet := make(map[gomatrixserverlib.ServerName]bool)
for _, event := range memberEvents {
serverSet[event.Origin()] = true
}
for server := range serverSet {
if server == b.thisServer {
continue
}
servers = append(servers, server)
}
b.servers = servers
return
}
// Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
func (b *backfillRequester) Backfill(ctx context.Context, server gomatrixserverlib.ServerName, roomID string,
fromEventIDs []string, limit int) (*gomatrixserverlib.Transaction, error) {
tx, err := b.fedClient.Backfill(ctx, server, roomID, limit, fromEventIDs)
return &tx, err
}
func (b *backfillRequester) ProvideEvents(roomVer gomatrixserverlib.RoomVersion, eventIDs []string) ([]gomatrixserverlib.Event, error) {
ctx := context.Background()
nidMap, err := b.db.EventNIDs(ctx, eventIDs)
if err != nil {
logrus.WithError(err).WithField("event_ids", eventIDs).Error("Failed to find events")
return nil, err
}
eventNIDs := make([]types.EventNID, len(nidMap))
i := 0
for _, nid := range nidMap {
eventNIDs[i] = nid
i++
}
eventsWithNids, err := b.db.Events(ctx, eventNIDs)
if err != nil {
logrus.WithError(err).WithField("event_nids", eventNIDs).Error("Failed to load events")
return nil, err
}
events := make([]gomatrixserverlib.Event, len(eventsWithNids))
for i := range eventsWithNids {
events[i] = eventsWithNids[i].Event
}
return events, nil
}
// joinEventsFromHistoryVisibility returns all CURRENTLY joined members if the provided state indicated a 'shared' history visibility.
// TODO: Long term we probably want a history_visibility table which stores eventNID | visibility_enum so we can just
// pull all events and then filter by that table.
func joinEventsFromHistoryVisibility(
ctx context.Context, db storage.Database, roomID string, stateEntries []types.StateEntry) ([]types.Event, error) {
var eventNIDs []types.EventNID
for _, entry := range stateEntries {
// Filter the events to retrieve to only keep the membership events
if entry.EventTypeNID == types.MRoomHistoryVisibilityNID && entry.EventStateKeyNID == types.EmptyStateKeyNID {
eventNIDs = append(eventNIDs, entry.EventNID)
break
}
}
// Get all of the events in this state
stateEvents, err := db.Events(ctx, eventNIDs)
if err != nil {
return nil, err
}
events := make([]gomatrixserverlib.Event, len(stateEvents))
for i := range stateEvents {
events[i] = stateEvents[i].Event
}
visibility := auth.HistoryVisibilityForRoom(events)
if visibility != "shared" {
logrus.Infof("ServersAtEvent history visibility not shared: %s", visibility)
return nil, nil
}
// get joined members
roomNID, err := db.RoomNID(ctx, roomID)
if err != nil {
return nil, err
}
joinEventNIDs, err := db.GetMembershipEventNIDsForRoom(ctx, roomNID, true)
if err != nil {
return nil, err
}
return db.Events(ctx, joinEventNIDs)
}

View file

@ -0,0 +1,157 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"encoding/json"
"testing"
"github.com/matrix-org/dendrite/common/test"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
// used to implement RoomserverInternalAPIEventDB to test getAuthChain
type getEventDB struct {
eventMap map[string]gomatrixserverlib.Event
}
func createEventDB() *getEventDB {
return &getEventDB{
eventMap: make(map[string]gomatrixserverlib.Event),
}
}
// Adds a fake event to the storage with given auth events.
func (db *getEventDB) addFakeEvent(eventID string, authIDs []string) error {
authEvents := []gomatrixserverlib.EventReference{}
for _, authID := range authIDs {
authEvents = append(authEvents, gomatrixserverlib.EventReference{
EventID: authID,
})
}
builder := map[string]interface{}{
"event_id": eventID,
"auth_events": authEvents,
}
eventJSON, err := json.Marshal(&builder)
if err != nil {
return err
}
event, err := gomatrixserverlib.NewEventFromTrustedJSON(
eventJSON, false, gomatrixserverlib.RoomVersionV1,
)
if err != nil {
return err
}
db.eventMap[eventID] = event
return nil
}
// Adds multiple events at once, each entry in the map is an eventID and set of
// auth events that are converted to an event and added.
func (db *getEventDB) addFakeEvents(graph map[string][]string) error {
for eventID, authIDs := range graph {
err := db.addFakeEvent(eventID, authIDs)
if err != nil {
return err
}
}
return nil
}
// EventsFromIDs implements RoomserverInternalAPIEventDB
func (db *getEventDB) EventsFromIDs(ctx context.Context, eventIDs []string) (res []types.Event, err error) {
for _, evID := range eventIDs {
res = append(res, types.Event{
EventNID: 0,
Event: db.eventMap[evID],
})
}
return
}
func TestGetAuthChainSingle(t *testing.T) {
db := createEventDB()
err := db.addFakeEvents(map[string][]string{
"a": {},
"b": {"a"},
"c": {"a", "b"},
"d": {"b", "c"},
"e": {"a", "d"},
})
if err != nil {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
var returnedIDs []string
for _, event := range result {
returnedIDs = append(returnedIDs, event.EventID())
}
expectedIDs := []string{"a", "b", "c", "d", "e"}
if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) {
t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
}
}
func TestGetAuthChainMultiple(t *testing.T) {
db := createEventDB()
err := db.addFakeEvents(map[string][]string{
"a": {},
"b": {"a"},
"c": {"a", "b"},
"d": {"b", "c"},
"e": {"a", "d"},
"f": {"a", "b", "c"},
})
if err != nil {
t.Fatalf("Failed to add events to db: %v", err)
}
result, err := getAuthChain(context.TODO(), db.EventsFromIDs, []string{"e", "f"})
if err != nil {
t.Fatalf("getAuthChain failed: %v", err)
}
var returnedIDs []string
for _, event := range result {
returnedIDs = append(returnedIDs, event.EventID())
}
expectedIDs := []string{"a", "b", "c", "d", "e", "f"}
if !test.UnsortedStringSliceEqual(expectedIDs, returnedIDs) {
t.Fatalf("returnedIDs got '%v', expected '%v'", returnedIDs, expectedIDs)
}
}