Initial Store & Forward Implementation (#2917)

This adds store & forward relays into dendrite for p2p.
A few things have changed:
- new relay api serves new http endpoints for s&f federation
- updated outbound federation queueing which will attempt to forward
using s&f if appropriate
- database entries to track s&f relays for other nodes
This commit is contained in:
devonh 2023-01-23 17:55:12 +00:00 committed by GitHub
parent 48fa869fa3
commit 5b73592f5a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
77 changed files with 7646 additions and 1373 deletions

View file

@ -18,6 +18,7 @@ type FederationInternalAPI interface {
gomatrixserverlib.KeyDatabase
ClientFederationAPI
RoomserverFederationAPI
P2PFederationAPI
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
@ -30,7 +31,6 @@ type FederationInternalAPI interface {
request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse,
) error
PerformWakeupServers(
ctx context.Context,
request *PerformWakeupServersRequest,
@ -71,6 +71,15 @@ type RoomserverFederationAPI interface {
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
type P2PFederationAPI interface {
// Relay Server sync api used in the pinecone demos.
P2PQueryRelayServers(
ctx context.Context,
request *P2PQueryRelayServersRequest,
response *P2PQueryRelayServersResponse,
) error
}
// KeyserverFederationAPI is a subset of gomatrixserverlib.FederationClient functions which the keyserver
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError
@ -82,6 +91,7 @@ type KeyserverFederationAPI interface {
// an interface for gmsl.FederationClient - contains functions called by federationapi only.
type FederationClient interface {
P2PFederationClient
gomatrixserverlib.KeyClient
SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error)
@ -110,6 +120,11 @@ type FederationClient interface {
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
type P2PFederationClient interface {
P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error)
P2PGetTransactionFromRelay(ctx context.Context, u gomatrixserverlib.UserID, prev gomatrixserverlib.RelayEntry, relayServer gomatrixserverlib.ServerName) (res gomatrixserverlib.RespGetRelayTransaction, err error)
}
// FederationClientError is returned from FederationClient methods in the event of a problem.
type FederationClientError struct {
Err string
@ -233,3 +248,11 @@ type InputPublicKeysRequest struct {
type InputPublicKeysResponse struct {
}
type P2PQueryRelayServersRequest struct {
Server gomatrixserverlib.ServerName
}
type P2PQueryRelayServersResponse struct {
RelayServers []gomatrixserverlib.ServerName
}

View file

@ -113,7 +113,10 @@ func NewInternalAPI(
_ = federationDB.RemoveAllServersFromBlacklist()
}
stats := statistics.NewStatistics(federationDB, cfg.FederationMaxRetries+1)
stats := statistics.NewStatistics(
federationDB,
cfg.FederationMaxRetries+1,
cfg.P2PFederationRetriesUntilAssumedOffline+1)
js, nats := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)

View file

@ -109,13 +109,14 @@ func NewFederationInternalAPI(
func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) {
stats := a.statistics.ForServer(s)
until, blacklisted := stats.BackoffInfo()
if blacklisted {
if stats.Blacklisted() {
return stats, &api.FederationClientError{
Blacklisted: true,
}
}
now := time.Now()
until := stats.BackoffInfo()
if until != nil && now.Before(*until) {
return stats, &api.FederationClientError{
RetryAfter: time.Until(*until),
@ -163,7 +164,7 @@ func (a *FederationInternalAPI) doRequestIfNotBackingOffOrBlacklisted(
RetryAfter: retryAfter,
}
}
stats.Success()
stats.Success(statistics.SendDirect)
return res, nil
}
@ -171,7 +172,7 @@ func (a *FederationInternalAPI) doRequestIfNotBlacklisted(
s gomatrixserverlib.ServerName, request func() (interface{}, error),
) (interface{}, error) {
stats := a.statistics.ForServer(s)
if _, blacklisted := stats.BackoffInfo(); blacklisted {
if blacklisted := stats.Blacklisted(); blacklisted {
return stats, &api.FederationClientError{
Err: fmt.Sprintf("server %q is blacklisted", s),
Blacklisted: true,

View file

@ -0,0 +1,202 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"fmt"
"testing"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
FailuresUntilAssumedOffline = 3
FailuresUntilBlacklist = 8
)
func (t *testFedClient) QueryKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, keys map[string][]string) (gomatrixserverlib.RespQueryKeys, error) {
t.queryKeysCalled = true
if t.shouldFail {
return gomatrixserverlib.RespQueryKeys{}, fmt.Errorf("Failure")
}
return gomatrixserverlib.RespQueryKeys{}, nil
}
func (t *testFedClient) ClaimKeys(ctx context.Context, origin, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (gomatrixserverlib.RespClaimKeys, error) {
t.claimKeysCalled = true
if t.shouldFail {
return gomatrixserverlib.RespClaimKeys{}, fmt.Errorf("Failure")
}
return gomatrixserverlib.RespClaimKeys{}, nil
}
func TestFederationClientQueryKeys(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.Nil(t, err)
assert.True(t, fedClient.queryKeysCalled)
}
func TestFederationClientQueryKeysBlacklisted(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
testDB.AddServerToBlacklist("server")
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.False(t, fedClient.queryKeysCalled)
}
func TestFederationClientQueryKeysFailure(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{shouldFail: true}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.QueryKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.True(t, fedClient.queryKeysCalled)
}
func TestFederationClientClaimKeys(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil)
assert.Nil(t, err)
assert.True(t, fedClient.claimKeysCalled)
}
func TestFederationClientClaimKeysBlacklisted(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
testDB.AddServerToBlacklist("server")
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "server",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedapi := FederationInternalAPI{
db: testDB,
cfg: &cfg,
statistics: &stats,
federation: fedClient,
queues: queues,
}
_, err := fedapi.ClaimKeys(context.Background(), "origin", "server", nil)
assert.NotNil(t, err)
assert.False(t, fedClient.claimKeysCalled)
}

View file

@ -14,6 +14,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/statistics"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
)
@ -24,6 +25,10 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) (err error) {
if !r.shouldAttemptDirectFederation(request.ServerName) {
return fmt.Errorf("relay servers have no meaningful response for directory lookup.")
}
dir, err := r.federation.LookupRoomAlias(
ctx,
r.cfg.Matrix.ServerName,
@ -36,7 +41,7 @@ func (r *FederationInternalAPI) PerformDirectoryLookup(
}
response.RoomID = dir.RoomID
response.ServerNames = dir.Servers
r.statistics.ForServer(request.ServerName).Success()
r.statistics.ForServer(request.ServerName).Success(statistics.SendDirect)
return nil
}
@ -144,6 +149,10 @@ func (r *FederationInternalAPI) performJoinUsingServer(
supportedVersions []gomatrixserverlib.RoomVersion,
unsigned map[string]interface{},
) error {
if !r.shouldAttemptDirectFederation(serverName) {
return fmt.Errorf("relay servers have no meaningful response for join.")
}
_, origin, err := r.cfg.Matrix.SplitLocalID('@', userID)
if err != nil {
return err
@ -164,7 +173,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd"
@ -219,7 +228,7 @@ func (r *FederationInternalAPI) performJoinUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.SendJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// If the remote server returned an event in the "event" key of
// the send_join request then we should use that instead. It may
@ -407,6 +416,10 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
if !r.shouldAttemptDirectFederation(serverName) {
return fmt.Errorf("relay servers have no meaningful response for outbound peek.")
}
// create a unique ID for this peek.
// for now we just use the room ID again. In future, if we ever
// support concurrent peeks to the same room with different filters
@ -446,7 +459,7 @@ func (r *FederationInternalAPI) performOutboundPeekUsingServer(
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.Peek: %w", err)
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
// Work out if we support the room version that has been supplied in
// the peek response.
@ -516,6 +529,10 @@ func (r *FederationInternalAPI) PerformLeave(
// Try each server that we were provided until we land on one that
// successfully completes the make-leave send-leave dance.
for _, serverName := range request.ServerNames {
if !r.shouldAttemptDirectFederation(serverName) {
continue
}
// Try to perform a make_leave using the information supplied in the
// request.
respMakeLeave, err := r.federation.MakeLeave(
@ -585,7 +602,7 @@ func (r *FederationInternalAPI) PerformLeave(
continue
}
r.statistics.ForServer(serverName).Success()
r.statistics.ForServer(serverName).Success(statistics.SendDirect)
return nil
}
@ -616,6 +633,12 @@ func (r *FederationInternalAPI) PerformInvite(
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
// TODO (devon): This should be allowed via a relay. Currently only transactions
// can be sent to relays. Would need to extend relays to handle invites.
if !r.shouldAttemptDirectFederation(destination) {
return fmt.Errorf("relay servers have no meaningful response for invite.")
}
logrus.WithFields(logrus.Fields{
"event_id": request.Event.EventID(),
"user_id": *request.Event.StateKey(),
@ -682,12 +705,8 @@ func (r *FederationInternalAPI) PerformWakeupServers(
func (r *FederationInternalAPI) MarkServersAlive(destinations []gomatrixserverlib.ServerName) {
for _, srv := range destinations {
// Check the statistics cache for the blacklist status to prevent hitting
// the database unnecessarily.
if r.queues.IsServerBlacklisted(srv) {
_ = r.db.RemoveServerFromBlacklist(srv)
}
r.queues.RetryServer(srv)
wasBlacklisted := r.statistics.ForServer(srv).MarkServerAlive()
r.queues.RetryServer(srv, wasBlacklisted)
}
}
@ -719,7 +738,9 @@ func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
return fmt.Errorf("auth chain response is missing m.room.create event")
}
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
func setDefaultRoomVersionFromJoinEvent(
joinEvent gomatrixserverlib.EventBuilder,
) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
@ -802,3 +823,31 @@ func federatedAuthProvider(
return returning, nil
}
}
// P2PQueryRelayServers implements api.FederationInternalAPI
func (r *FederationInternalAPI) P2PQueryRelayServers(
ctx context.Context,
request *api.P2PQueryRelayServersRequest,
response *api.P2PQueryRelayServersResponse,
) error {
logrus.Infof("Getting relay servers for: %s", request.Server)
relayServers, err := r.db.P2PGetRelayServersForServer(ctx, request.Server)
if err != nil {
return err
}
response.RelayServers = relayServers
return nil
}
func (r *FederationInternalAPI) shouldAttemptDirectFederation(
destination gomatrixserverlib.ServerName,
) bool {
var shouldRelay bool
stats := r.statistics.ForServer(destination)
if stats.AssumedOffline() && len(stats.KnownRelayServers()) > 0 {
shouldRelay = true
}
return !shouldRelay
}

View file

@ -0,0 +1,190 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package internal
import (
"context"
"testing"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
type testFedClient struct {
api.FederationClient
queryKeysCalled bool
claimKeysCalled bool
shouldFail bool
}
func (t *testFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) {
return gomatrixserverlib.RespDirectory{}, nil
}
func TestPerformWakeupServers(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
testDB.AddServerToBlacklist(server)
testDB.SetServerAssumedOffline(context.Background(), server)
blacklisted, err := testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.True(t, blacklisted)
offline, err := testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.True(t, offline)
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformWakeupServersRequest{
ServerNames: []gomatrixserverlib.ServerName{server},
}
res := api.PerformWakeupServersResponse{}
err = fedAPI.PerformWakeupServers(context.Background(), &req, &res)
assert.NoError(t, err)
blacklisted, err = testDB.IsServerBlacklisted(server)
assert.NoError(t, err)
assert.False(t, blacklisted)
offline, err = testDB.IsServerAssumedOffline(context.Background(), server)
assert.NoError(t, err)
assert.False(t, offline)
}
func TestQueryRelayServers(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
relayServers := []gomatrixserverlib.ServerName{"relay1", "relay2"}
err := testDB.P2PAddRelayServersForServer(context.Background(), server, relayServers)
assert.NoError(t, err)
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.P2PQueryRelayServersRequest{
Server: server,
}
res := api.P2PQueryRelayServersResponse{}
err = fedAPI.P2PQueryRelayServers(context.Background(), &req, &res)
assert.NoError(t, err)
assert.Equal(t, len(relayServers), len(res.RelayServers))
}
func TestPerformDirectoryLookup(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: "relay",
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformDirectoryLookupRequest{
RoomAlias: "room",
ServerName: "server",
}
res := api.PerformDirectoryLookupResponse{}
err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res)
assert.NoError(t, err)
}
func TestPerformDirectoryLookupRelaying(t *testing.T) {
testDB := test.NewInMemoryFederationDatabase()
server := gomatrixserverlib.ServerName("wakeup")
testDB.SetServerAssumedOffline(context.Background(), server)
testDB.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{"relay"})
cfg := config.FederationAPI{
Matrix: &config.Global{
SigningIdentity: gomatrixserverlib.SigningIdentity{
ServerName: server,
},
},
}
fedClient := &testFedClient{}
stats := statistics.NewStatistics(testDB, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
queues := queue.NewOutgoingQueues(
testDB, process.NewProcessContext(),
false,
cfg.Matrix.ServerName, fedClient, nil, &stats,
nil,
)
fedAPI := NewFederationInternalAPI(
testDB, &cfg, nil, fedClient, &stats, nil, queues, nil,
)
req := api.PerformDirectoryLookupRequest{
RoomAlias: "room",
ServerName: server,
}
res := api.PerformDirectoryLookupResponse{}
err := fedAPI.PerformDirectoryLookup(context.Background(), &req, &res)
assert.Error(t, err)
}

View file

@ -24,6 +24,7 @@ const (
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIPerformWakeupServers = "/federationapi/performWakeupServers"
FederationAPIQueryRelayServers = "/federationapi/queryRelayServers"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
@ -510,3 +511,14 @@ func (h *httpFederationInternalAPI) QueryPublicKeys(
h.httpClient, ctx, request, response,
)
}
func (h *httpFederationInternalAPI) P2PQueryRelayServers(
ctx context.Context,
request *api.P2PQueryRelayServersRequest,
response *api.P2PQueryRelayServersResponse,
) error {
return httputil.CallInternalRPCAPI(
"QueryRelayServers", h.federationAPIURL+FederationAPIQueryRelayServers,
h.httpClient, ctx, request, response,
)
}

View file

@ -29,7 +29,7 @@ import (
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
@ -70,7 +70,7 @@ type destinationQueue struct {
// Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
@ -84,8 +84,8 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
pdu: event,
dbReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -101,7 +101,7 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
// sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, dbReceipt *receipt.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
@ -115,8 +115,8 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
edu: event,
dbReceipt: dbReceipt,
})
} else {
oq.overflowed.Store(true)
@ -210,10 +210,10 @@ func (oq *destinationQueue) getPendingFromDatabase() {
gotPDUs := map[string]struct{}{}
gotEDUs := map[string]struct{}{}
for _, pdu := range oq.pendingPDUs {
gotPDUs[pdu.receipt.String()] = struct{}{}
gotPDUs[pdu.dbReceipt.String()] = struct{}{}
}
for _, edu := range oq.pendingEDUs {
gotEDUs[edu.receipt.String()] = struct{}{}
gotEDUs[edu.dbReceipt.String()] = struct{}{}
}
overflowed := false
@ -371,7 +371,7 @@ func (oq *destinationQueue) backgroundSend() {
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
terr, sendMethod := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
_, blacklisted := oq.statistics.Failure()
@ -388,18 +388,19 @@ func (oq *destinationQueue) backgroundSend() {
return
}
} else {
oq.handleTransactionSuccess(pduCount, eduCount)
oq.handleTransactionSuccess(pduCount, eduCount, sendMethod)
}
}
}
// nextTransaction creates a new transaction from the pending event
// queue and sends it.
// Returns an error if the transaction wasn't sent.
// Returns an error if the transaction wasn't sent. And whether the success
// was to a relay server or not.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) error {
) (err error, sendMethod statistics.SendMethod) {
// Create the transaction.
t, pduReceipts, eduReceipts := oq.createTransaction(pdus, edus)
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
@ -407,7 +408,37 @@ func (oq *destinationQueue) nextTransaction(
// Try to send the transaction to the destination server.
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel()
_, err := oq.client.SendTransaction(ctx, t)
relayServers := oq.statistics.KnownRelayServers()
if oq.statistics.AssumedOffline() && len(relayServers) > 0 {
sendMethod = statistics.SendViaRelay
relaySuccess := false
logrus.Infof("Sending to relay servers: %v", relayServers)
// TODO : how to pass through actual userID here?!?!?!?!
userID, userErr := gomatrixserverlib.NewUserID("@user:"+string(oq.destination), false)
if userErr != nil {
return userErr, sendMethod
}
// Attempt sending to each known relay server.
for _, relayServer := range relayServers {
_, relayErr := oq.client.P2PSendTransactionToRelay(ctx, *userID, t, relayServer)
if relayErr != nil {
err = relayErr
} else {
// If sending to one of the relay servers succeeds, consider the send successful.
relaySuccess = true
}
}
// Clear the error if sending to any of the relay servers succeeded.
if relaySuccess {
err = nil
}
} else {
sendMethod = statistics.SendDirect
_, err = oq.client.SendTransaction(ctx, t)
}
switch errResponse := err.(type) {
case nil:
// Clean up the transaction in the database.
@ -427,7 +458,7 @@ func (oq *destinationQueue) nextTransaction(
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return nil
return nil, sendMethod
case gomatrix.HTTPError:
// Report that we failed to send the transaction and we
// will retry again, subject to backoff.
@ -437,13 +468,13 @@ func (oq *destinationQueue) nextTransaction(
// to a 400-ish error
code := errResponse.Code
logrus.Debug("Transaction failed with HTTP", code)
return err
return err, sendMethod
default:
logrus.WithFields(logrus.Fields{
"destination": oq.destination,
logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID)
return err
return err, sendMethod
}
}
@ -453,7 +484,7 @@ func (oq *destinationQueue) nextTransaction(
func (oq *destinationQueue) createTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (gomatrixserverlib.Transaction, []*shared.Receipt, []*shared.Receipt) {
) (gomatrixserverlib.Transaction, []*receipt.Receipt, []*receipt.Receipt) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
@ -474,8 +505,8 @@ func (oq *destinationQueue) createTransaction(
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
var pduReceipts []*receipt.Receipt
var eduReceipts []*receipt.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
@ -487,7 +518,7 @@ func (oq *destinationQueue) createTransaction(
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
pduReceipts = append(pduReceipts, pdu.dbReceipt)
}
// Do the same for pending EDUS in the queue.
@ -497,7 +528,7 @@ func (oq *destinationQueue) createTransaction(
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
eduReceipts = append(eduReceipts, edu.dbReceipt)
}
return t, pduReceipts, eduReceipts
@ -530,10 +561,11 @@ func (oq *destinationQueue) blacklistDestination() {
// handleTransactionSuccess updates the cached event queues as well as the success and
// backoff information for this server.
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int) {
func (oq *destinationQueue) handleTransactionSuccess(pduCount int, eduCount int, sendMethod statistics.SendMethod) {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.statistics.Success(sendMethod)
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()

View file

@ -30,7 +30,7 @@ import (
fedapi "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
)
@ -138,13 +138,13 @@ func NewOutgoingQueues(
}
type queuedPDU struct {
receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent
dbReceipt *receipt.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *shared.Receipt
edu *gomatrixserverlib.EDU
dbReceipt *receipt.Receipt
edu *gomatrixserverlib.EDU
}
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
@ -374,24 +374,13 @@ func (oqs *OutgoingQueues) SendEDU(
return nil
}
// IsServerBlacklisted returns whether or not the provided server is currently
// blacklisted.
func (oqs *OutgoingQueues) IsServerBlacklisted(srv gomatrixserverlib.ServerName) bool {
return oqs.statistics.ForServer(srv).Blacklisted()
}
// RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName, wasBlacklisted bool) {
if oqs.disabled {
return
}
serverStatistics := oqs.statistics.ForServer(srv)
forceWakeup := serverStatistics.Blacklisted()
serverStatistics.RemoveBlacklist()
serverStatistics.ClearBackoff()
if queue := oqs.getQueue(srv); queue != nil {
queue.wakeQueueIfEventsPending(forceWakeup)
queue.wakeQueueIfEventsPending(wasBlacklisted)
}
}

View file

@ -18,7 +18,6 @@ import (
"context"
"encoding/json"
"fmt"
"sync"
"testing"
"time"
@ -26,13 +25,11 @@ import (
"gotest.tools/v3/poll"
"github.com/matrix-org/gomatrixserverlib"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
rsapi "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
@ -57,7 +54,7 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase
}
} else {
// Fake Database
db := createDatabase()
db := test.NewInMemoryFederationDatabase()
b := struct {
ProcessContext *process.ProcessContext
}{ProcessContext: process.NewProcessContext()}
@ -65,220 +62,6 @@ func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase
}
}
func createDatabase() storage.Database {
return &fakeDatabase{
pendingPDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingEDUServers: make(map[gomatrixserverlib.ServerName]struct{}),
blacklistedServers: make(map[gomatrixserverlib.ServerName]struct{}),
pendingPDUs: make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent),
pendingEDUs: make(map[*shared.Receipt]*gomatrixserverlib.EDU),
associatedPDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
associatedEDUs: make(map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}),
}
}
type fakeDatabase struct {
storage.Database
dbMutex sync.Mutex
pendingPDUServers map[gomatrixserverlib.ServerName]struct{}
pendingEDUServers map[gomatrixserverlib.ServerName]struct{}
blacklistedServers map[gomatrixserverlib.ServerName]struct{}
pendingPDUs map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent
pendingEDUs map[*shared.Receipt]*gomatrixserverlib.EDU
associatedPDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
associatedEDUs map[gomatrixserverlib.ServerName]map[*shared.Receipt]struct{}
}
var nidMutex sync.Mutex
var nid = int64(0)
func (d *fakeDatabase) StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal([]byte(js), &event); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
receipt := shared.NewReceipt(nid)
d.pendingPDUs[&receipt] = &event
return &receipt, nil
}
var edu gomatrixserverlib.EDU
if err := json.Unmarshal([]byte(js), &edu); err == nil {
nidMutex.Lock()
defer nidMutex.Unlock()
nid++
receipt := shared.NewReceipt(nid)
d.pendingEDUs[&receipt] = &edu
return &receipt, nil
}
return nil, errors.New("Failed to determine type of json to store")
}
func (d *fakeDatabase) GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
pduCount := 0
pdus = make(map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent)
if receipts, ok := d.associatedPDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingPDUs[receipt]; ok {
pdus[receipt] = event
pduCount++
if pduCount == limit {
break
}
}
}
}
return pdus, nil
}
func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
eduCount := 0
edus = make(map[*shared.Receipt]*gomatrixserverlib.EDU)
if receipts, ok := d.associatedEDUs[serverName]; ok {
for receipt := range receipts {
if event, ok := d.pendingEDUs[receipt]; ok {
edus[receipt] = event
eduCount++
if eduCount == limit {
break
}
}
}
}
return edus, nil
}
func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingPDUs[receipt]; ok {
for destination := range destinations {
if _, ok := d.associatedPDUs[destination]; !ok {
d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{})
}
d.associatedPDUs[destination][receipt] = struct{}{}
}
return nil
} else {
return errors.New("PDU doesn't exist")
}
}
func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if _, ok := d.pendingEDUs[receipt]; ok {
for destination := range destinations {
if _, ok := d.associatedEDUs[destination]; !ok {
d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{})
}
d.associatedEDUs[destination][receipt] = struct{}{}
}
return nil
} else {
return errors.New("EDU doesn't exist")
}
}
func (d *fakeDatabase) CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if pdus, ok := d.associatedPDUs[serverName]; ok {
for _, receipt := range receipts {
delete(pdus, receipt)
}
}
return nil
}
func (d *fakeDatabase) CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
if edus, ok := d.associatedEDUs[serverName]; ok {
for _, receipt := range receipts {
delete(edus, receipt)
}
}
return nil
}
func (d *fakeDatabase) GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingPDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
servers := []gomatrixserverlib.ServerName{}
for server := range d.pendingEDUServers {
servers = append(servers, server)
}
return servers, nil
}
func (d *fakeDatabase) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers[serverName] = struct{}{}
return nil
}
func (d *fakeDatabase) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
delete(d.blacklistedServers, serverName)
return nil
}
func (d *fakeDatabase) RemoveAllServersFromBlacklist() error {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
d.blacklistedServers = make(map[gomatrixserverlib.ServerName]struct{})
return nil
}
func (d *fakeDatabase) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
d.dbMutex.Lock()
defer d.dbMutex.Unlock()
isBlacklisted := false
if _, ok := d.blacklistedServers[serverName]; ok {
isBlacklisted = true
}
return isBlacklisted, nil
}
type stubFederationRoomServerAPI struct {
rsapi.FederationRoomserverAPI
}
@ -290,8 +73,10 @@ func (r *stubFederationRoomServerAPI) QueryServerBannedFromRoom(ctx context.Cont
type stubFederationClient struct {
api.FederationClient
shouldTxSucceed bool
txCount atomic.Uint32
shouldTxSucceed bool
shouldTxRelaySucceed bool
txCount atomic.Uint32
txRelayCount atomic.Uint32
}
func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixserverlib.Transaction) (res gomatrixserverlib.RespSend, err error) {
@ -304,6 +89,16 @@ func (f *stubFederationClient) SendTransaction(ctx context.Context, t gomatrixse
return gomatrixserverlib.RespSend{}, result
}
func (f *stubFederationClient) P2PSendTransactionToRelay(ctx context.Context, u gomatrixserverlib.UserID, t gomatrixserverlib.Transaction, forwardingServer gomatrixserverlib.ServerName) (res gomatrixserverlib.EmptyResp, err error) {
var result error
if !f.shouldTxRelaySucceed {
result = fmt.Errorf("relay transaction failed")
}
f.txRelayCount.Add(1)
return gomatrixserverlib.EmptyResp{}, result
}
func mustCreatePDU(t *testing.T) *gomatrixserverlib.HeaderedEvent {
t.Helper()
content := `{"type":"m.room.message"}`
@ -319,15 +114,18 @@ func mustCreateEDU(t *testing.T) *gomatrixserverlib.EDU {
return &gomatrixserverlib.EDU{Type: gomatrixserverlib.MTyping}
}
func testSetup(failuresUntilBlacklist uint32, shouldTxSucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
func testSetup(failuresUntilBlacklist uint32, failuresUntilAssumedOffline uint32, shouldTxSucceed bool, shouldTxRelaySucceed bool, t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *stubFederationClient, *OutgoingQueues, *process.ProcessContext, func()) {
db, processContext, close := mustCreateFederationDatabase(t, dbType, realDatabase)
fc := &stubFederationClient{
shouldTxSucceed: shouldTxSucceed,
txCount: *atomic.NewUint32(0),
shouldTxSucceed: shouldTxSucceed,
shouldTxRelaySucceed: shouldTxRelaySucceed,
txCount: *atomic.NewUint32(0),
txRelayCount: *atomic.NewUint32(0),
}
rs := &stubFederationRoomServerAPI{}
stats := statistics.NewStatistics(db, failuresUntilBlacklist)
stats := statistics.NewStatistics(db, failuresUntilBlacklist, failuresUntilAssumedOffline)
signingInfo := []*gomatrixserverlib.SigningIdentity{
{
KeyID: "ed21019:auto",
@ -344,7 +142,7 @@ func TestSendPDUOnSuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -373,7 +171,7 @@ func TestSendEDUOnSuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -402,7 +200,7 @@ func TestSendPDUOnFailStoredInDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -432,7 +230,7 @@ func TestSendEDUOnFailStoredInDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -462,7 +260,7 @@ func TestSendPDUAgainDoesntInterruptBackoff(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -513,7 +311,7 @@ func TestSendEDUAgainDoesntInterruptBackoff(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -564,7 +362,7 @@ func TestSendPDUMultipleFailuresBlacklisted(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -596,7 +394,7 @@ func TestSendEDUMultipleFailuresBlacklisted(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -628,7 +426,7 @@ func TestSendPDUBlacklistedWithPriorExternalFailure(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -662,7 +460,7 @@ func TestSendEDUBlacklistedWithPriorExternalFailure(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -696,7 +494,7 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -730,8 +528,8 @@ func TestRetryServerSendsPDUSuccessfully(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
@ -747,7 +545,7 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -781,8 +579,8 @@ func TestRetryServerSendsEDUSuccessfully(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
@ -801,7 +599,7 @@ func TestSendPDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -845,7 +643,7 @@ func TestSendEDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -889,7 +687,7 @@ func TestSendPDUAndEDUBatches(t *testing.T) {
// test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
// db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -940,7 +738,7 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, true, t, test.DBTypeSQLite, false)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, true, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
@ -978,7 +776,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
destination := gomatrixserverlib.ServerName("remotehost")
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilBlacklist+1, false, false, t, dbType, true)
// NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
defer close()
defer func() {
@ -1023,8 +821,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
poll.WaitOn(t, checkBlacklisted, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
fc.shouldTxSucceed = true
db.RemoveServerFromBlacklist(destination)
queues.RetryServer(destination)
wasBlacklisted := dest.statistics.MarkServerAlive()
queues.RetryServer(destination, wasBlacklisted)
checkRetry := func(log poll.LogT) poll.Result {
pduData, dbErrPDU := db.GetPendingPDUs(pc.Context(), destination, 200)
assert.NoError(t, dbErrPDU)
@ -1038,3 +836,147 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
poll.WaitOn(t, checkRetry, poll.WithTimeout(10*time.Second), poll.WithDelay(100*time.Millisecond))
})
}
func TestSendPDUMultipleFailuresAssumedOffline(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(7)
failuresUntilAssumedOffline := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
ev := mustCreatePDU(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == failuresUntilAssumedOffline {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
}
return poll.Continue("waiting for event to be added to database. Currently present PDU: %d", len(data))
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendEDUMultipleFailuresAssumedOffline(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(7)
failuresUntilAssumedOffline := uint32(2)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, false, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
ev := mustCreateEDU(t)
err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == failuresUntilAssumedOffline {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 1 {
if val, _ := db.IsServerAssumedOffline(context.Background(), destination); val {
return poll.Success()
}
return poll.Continue("waiting for server to be assumed offline")
}
return poll.Continue("waiting for event to be added to database. Currently present EDU: %d", len(data))
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
}
func TestSendPDUOnRelaySuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
failuresUntilAssumedOffline := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreatePDU(t)
err := queues.SendEvent(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == 1 {
if fc.txRelayCount.Load() == 1 {
data, dbErr := db.GetPendingPDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database. Currently present PDU: %d", len(data))
}
return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load())
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}
func TestSendEDUOnRelaySuccessRemovedFromDB(t *testing.T) {
t.Parallel()
failuresUntilBlacklist := uint32(16)
failuresUntilAssumedOffline := uint32(1)
destination := gomatrixserverlib.ServerName("remotehost")
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, failuresUntilAssumedOffline, false, true, t, test.DBTypeSQLite, false)
defer close()
defer func() {
pc.ShutdownDendrite()
<-pc.WaitForShutdown()
}()
relayServers := []gomatrixserverlib.ServerName{"relayserver"}
queues.statistics.ForServer(destination).AddRelayServers(relayServers)
ev := mustCreateEDU(t)
err := queues.SendEDU(ev, "localhost", []gomatrixserverlib.ServerName{destination})
assert.NoError(t, err)
check := func(log poll.LogT) poll.Result {
if fc.txCount.Load() == 1 {
if fc.txRelayCount.Load() == 1 {
data, dbErr := db.GetPendingEDUs(pc.Context(), destination, 100)
assert.NoError(t, dbErr)
if len(data) == 0 {
return poll.Success()
}
return poll.Continue("waiting for event to be removed from database. Currently present EDU: %d", len(data))
}
return poll.Continue("waiting for more relay send attempts before checking database. Currently %d", fc.txRelayCount.Load())
}
return poll.Continue("waiting for more send attempts before checking database. Currently %d", fc.txCount.Load())
}
poll.WaitOn(t, check, poll.WithTimeout(5*time.Second), poll.WithDelay(100*time.Millisecond))
assumedOffline, _ := db.IsServerAssumedOffline(context.Background(), destination)
assert.Equal(t, true, assumedOffline)
}

View file

@ -0,0 +1,94 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"io"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
userAPI "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
type fakeUserAPI struct {
userAPI.FederationUserAPI
}
func (u *fakeUserAPI) QueryProfile(ctx context.Context, req *userAPI.QueryProfileRequest, res *userAPI.QueryProfileResponse) error {
return nil
}
func TestHandleQueryProfile(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedClient := fakeFedClient{}
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true)
userapi := fakeUserAPI{}
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryProfileRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/profile?user_id="+url.QueryEscape("@user:"+string(testOrigin)))
type queryContent struct{}
content := queryContent{}
err := req.SetContent(content)
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
// vars := map[string]string{"room_alias": "#room:server"}
w := httptest.NewRecorder()
// httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
data, _ := io.ReadAll(res.Body)
println(string(data))
assert.Equal(t, 200, res.StatusCode)
})
}

View file

@ -0,0 +1,94 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"io"
"net/http/httptest"
"net/url"
"testing"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedclient "github.com/matrix-org/dendrite/federationapi/api"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
type fakeFedClient struct {
fedclient.FederationClient
}
func (f *fakeFedClient) LookupRoomAlias(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomAlias string) (res gomatrixserverlib.RespDirectory, err error) {
return
}
func TestHandleQueryDirectory(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedClient := fakeFedClient{}
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
fedapi := fedAPI.NewInternalAPI(base, &fedClient, nil, nil, keyRing, true)
userapi := fakeUserAPI{}
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, &fedClient, &userapi, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.QueryDirectoryRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("GET", serverName, testOrigin, "/query/directory?room_alias="+url.QueryEscape("#room:server"))
type queryContent struct{}
content := queryContent{}
err := req.SetContent(content)
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
// vars := map[string]string{"room_alias": "#room:server"}
w := httptest.NewRecorder()
// httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
data, _ := io.ReadAll(res.Body)
println(string(data))
assert.Equal(t, 200, res.StatusCode)
})
}

View file

@ -41,6 +41,12 @@ import (
"github.com/sirupsen/logrus"
)
const (
SendRouteName = "Send"
QueryDirectoryRouteName = "QueryDirectory"
QueryProfileRouteName = "QueryProfile"
)
// Setup registers HTTP handlers with the given ServeMux.
// The provided publicAPIMux MUST have `UseEncodedPath()` enabled or else routes will incorrectly
// path unescape twice (once from the router, once from MakeFedAPI). We need to have this enabled
@ -68,7 +74,7 @@ func Setup(
if base.EnableMetrics {
prometheus.MustRegister(
pduCountTotal, eduCountTotal,
internal.PDUCountTotal, internal.EDUCountTotal,
)
}
@ -138,7 +144,7 @@ func Setup(
cfg, rsAPI, keyAPI, keys, federation, mu, servers, producer,
)
},
)).Methods(http.MethodPut, http.MethodOptions)
)).Methods(http.MethodPut, http.MethodOptions).Name(SendRouteName)
v1fedmux.Handle("/invite/{roomID}/{eventID}", MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
@ -248,7 +254,7 @@ func Setup(
httpReq, federation, cfg, rsAPI, fsAPI,
)
},
)).Methods(http.MethodGet)
)).Methods(http.MethodGet).Name(QueryDirectoryRouteName)
v1fedmux.Handle("/query/profile", MakeFedAPI(
"federation_query_profile", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,
@ -257,7 +263,7 @@ func Setup(
httpReq, userAPI, cfg,
)
},
)).Methods(http.MethodGet)
)).Methods(http.MethodGet).Name(QueryProfileRouteName)
v1fedmux.Handle("/user/devices/{userID}", MakeFedAPI(
"federation_user_devices", cfg.Matrix.ServerName, cfg.Matrix.IsLocalServerName, keys, wakeup,

View file

@ -17,26 +17,20 @@ package routing
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"time"
"github.com/getsentry/sentry-go"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/producers"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
syncTypes "github.com/matrix-org/dendrite/syncapi/types"
)
const (
@ -56,26 +50,6 @@ const (
MetricsWorkMissingPrevEvents = "missing_prev_events"
)
var (
pduCountTotal = prometheus.NewCounterVec(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_pdus",
Help: "Number of incoming PDUs from remote servers with labels for success",
},
[]string{"status"}, // 'success' or 'total'
)
eduCountTotal = prometheus.NewCounter(
prometheus.CounterOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "recv_edus",
Help: "Number of incoming EDUs from remote servers",
},
)
)
var inFlightTxnsPerOrigin sync.Map // transaction ID -> chan util.JSONResponse
// Send implements /_matrix/federation/v1/send/{txnID}
@ -123,18 +97,6 @@ func Send(
defer close(ch)
defer inFlightTxnsPerOrigin.Delete(index)
t := txnReq{
rsAPI: rsAPI,
keys: keys,
ourServerName: cfg.Matrix.ServerName,
federation: federation,
servers: servers,
keyAPI: keyAPI,
roomsMu: mu,
producer: producer,
inboundPresenceEnabled: cfg.Matrix.Presence.EnableInbound,
}
var txnEvents struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
@ -155,16 +117,23 @@ func Send(
}
}
// TODO: Really we should have a function to convert FederationRequest to txnReq
t.PDUs = txnEvents.PDUs
t.EDUs = txnEvents.EDUs
t.Origin = request.Origin()
t.TransactionID = txnID
t.Destination = cfg.Matrix.ServerName
t := internal.NewTxnReq(
rsAPI,
keyAPI,
cfg.Matrix.ServerName,
keys,
mu,
producer,
cfg.Matrix.Presence.EnableInbound,
txnEvents.PDUs,
txnEvents.EDUs,
request.Origin(),
txnID,
cfg.Matrix.ServerName)
util.GetLogger(httpReq.Context()).Debugf("Received transaction %q from %q containing %d PDUs, %d EDUs", txnID, request.Origin(), len(t.PDUs), len(t.EDUs))
resp, jsonErr := t.processTransaction(httpReq.Context())
resp, jsonErr := t.ProcessTransaction(httpReq.Context())
if jsonErr != nil {
util.GetLogger(httpReq.Context()).WithField("jsonErr", jsonErr).Error("t.processTransaction failed")
return *jsonErr
@ -181,283 +150,3 @@ func Send(
ch <- res
return res
}
type txnReq struct {
gomatrixserverlib.Transaction
rsAPI api.FederationRoomserverAPI
keyAPI keyapi.FederationKeyAPI
ourServerName gomatrixserverlib.ServerName
keys gomatrixserverlib.JSONVerifier
federation txnFederationClient
roomsMu *internal.MutexByRoom
servers federationAPI.ServersInRoomProvider
producer *producers.SyncAPIProducer
inboundPresenceEnabled bool
}
// A subset of FederationClient functionality that txn requires. Useful for testing.
type txnFederationClient interface {
LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error,
)
LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error)
GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error)
}
func (t *txnReq) processTransaction(ctx context.Context) (*gomatrixserverlib.RespSend, *util.JSONResponse) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
t.processEDUs(ctx)
}()
results := make(map[string]gomatrixserverlib.PDUResult)
roomVersions := make(map[string]gomatrixserverlib.RoomVersion)
getRoomVersion := func(roomID string) gomatrixserverlib.RoomVersion {
if v, ok := roomVersions[roomID]; ok {
return v
}
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}
if err := t.rsAPI.QueryRoomVersionForRoom(ctx, &verReq, &verRes); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to query room version for room", verReq.RoomID)
return ""
}
roomVersions[roomID] = verRes.RoomVersion
return verRes.RoomVersion
}
for _, pdu := range t.PDUs {
pduCountTotal.WithLabelValues("total").Inc()
var header struct {
RoomID string `json:"room_id"`
}
if err := json.Unmarshal(pdu, &header); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Transaction: Failed to extract room ID from event")
// We don't know the event ID at this point so we can't return the
// failure in the PDU results
continue
}
roomVersion := getRoomVersion(header.RoomID)
event, err := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if err != nil {
if _, ok := err.(gomatrixserverlib.BadJSONError); ok {
// Room version 6 states that homeservers should strictly enforce canonical JSON
// on PDUs.
//
// This enforces that the entire transaction is rejected if a single bad PDU is
// sent. It is unclear if this is the correct behaviour or not.
//
// See https://github.com/matrix-org/synapse/issues/7543
return nil, &util.JSONResponse{
Code: 400,
JSON: jsonerror.BadJSON("PDU contains bad JSON"),
}
}
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Failed to parse event JSON of event %s", string(pdu))
continue
}
if event.Type() == gomatrixserverlib.MRoomCreate && event.StateKeyEquals("") {
continue
}
if api.IsServerBannedFromRoom(ctx, t.rsAPI, event.RoomID(), t.Origin) {
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: "Forbidden by server ACLs",
}
continue
}
if err = event.VerifyEventSignatures(ctx, t.keys); err != nil {
util.GetLogger(ctx).WithError(err).Debugf("Transaction: Couldn't validate signature of event %q", event.EventID())
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
// pass the event to the roomserver which will do auth checks
// If the event fail auth checks, gmsl.NotAllowed error will be returned which we be silently
// discarded by the caller of this function
if err = api.SendEvents(
ctx,
t.rsAPI,
api.KindNew,
[]*gomatrixserverlib.HeaderedEvent{
event.Headered(roomVersion),
},
t.Destination,
t.Origin,
api.DoNotSendToOtherServers,
nil,
true,
); err != nil {
util.GetLogger(ctx).WithError(err).Errorf("Transaction: Couldn't submit event %q to input queue: %s", event.EventID(), err)
results[event.EventID()] = gomatrixserverlib.PDUResult{
Error: err.Error(),
}
continue
}
results[event.EventID()] = gomatrixserverlib.PDUResult{}
pduCountTotal.WithLabelValues("success").Inc()
}
wg.Wait()
return &gomatrixserverlib.RespSend{PDUs: results}, nil
}
// nolint:gocyclo
func (t *txnReq) processEDUs(ctx context.Context) {
for _, e := range t.EDUs {
eduCountTotal.Inc()
switch e.Type {
case gomatrixserverlib.MTyping:
// https://matrix.org/docs/spec/server_server/latest#typing-notifications
var typingPayload struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
Typing bool `json:"typing"`
}
if err := json.Unmarshal(e.Content, &typingPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal typing event")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', typingPayload.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
if err := t.producer.SendTyping(ctx, typingPayload.UserID, typingPayload.RoomID, typingPayload.Typing, 30*1000); err != nil {
util.GetLogger(ctx).WithError(err).Error("Failed to send typing event to JetStream")
}
case gomatrixserverlib.MDirectToDevice:
// https://matrix.org/docs/spec/server_server/r0.1.3#m-direct-to-device-schema
var directPayload gomatrixserverlib.ToDeviceMessage
if err := json.Unmarshal(e.Content, &directPayload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal send-to-device events")
continue
}
if _, serverName, err := gomatrixserverlib.SplitID('@', directPayload.Sender); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
for userID, byUser := range directPayload.Messages {
for deviceID, message := range byUser {
// TODO: check that the user and the device actually exist here
if err := t.producer.SendToDevice(ctx, directPayload.Sender, userID, deviceID, directPayload.Type, message); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": directPayload.Sender,
"user_id": userID,
"device_id": deviceID,
}).Error("Failed to send send-to-device event to JetStream")
}
}
}
case gomatrixserverlib.MDeviceListUpdate:
if err := t.producer.SendDeviceListUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
util.GetLogger(ctx).WithError(err).Error("failed to InputDeviceListUpdate")
}
case gomatrixserverlib.MReceipt:
// https://matrix.org/docs/spec/server_server/r0.1.4#receipts
payload := map[string]types.FederationReceiptMRead{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to unmarshal receipt event")
continue
}
for roomID, receipt := range payload {
for userID, mread := range receipt.User {
_, domain, err := gomatrixserverlib.SplitID('@', userID)
if err != nil {
util.GetLogger(ctx).WithError(err).Debug("Failed to split domain from receipt event sender")
continue
}
if t.Origin != domain {
util.GetLogger(ctx).Debugf("Dropping receipt event where sender domain (%q) doesn't match origin (%q)", domain, t.Origin)
continue
}
if err := t.processReceiptEvent(ctx, userID, roomID, "m.read", mread.Data.TS, mread.EventIDs); err != nil {
util.GetLogger(ctx).WithError(err).WithFields(logrus.Fields{
"sender": t.Origin,
"user_id": userID,
"room_id": roomID,
"events": mread.EventIDs,
}).Error("Failed to send receipt event to JetStream")
continue
}
}
}
case types.MSigningKeyUpdate:
if err := t.producer.SendSigningKeyUpdate(ctx, e.Content, t.Origin); err != nil {
sentry.CaptureException(err)
logrus.WithError(err).Errorf("Failed to process signing key update")
}
case gomatrixserverlib.MPresence:
if t.inboundPresenceEnabled {
if err := t.processPresence(ctx, e); err != nil {
logrus.WithError(err).Errorf("Failed to process presence update")
}
}
default:
util.GetLogger(ctx).WithField("type", e.Type).Debug("Unhandled EDU")
}
}
}
// processPresence handles m.receipt events
func (t *txnReq) processPresence(ctx context.Context, e gomatrixserverlib.EDU) error {
payload := types.Presence{}
if err := json.Unmarshal(e.Content, &payload); err != nil {
return err
}
for _, content := range payload.Push {
if _, serverName, err := gomatrixserverlib.SplitID('@', content.UserID); err != nil {
continue
} else if serverName == t.ourServerName {
continue
} else if serverName != t.Origin {
continue
}
presence, ok := syncTypes.PresenceFromString(content.Presence)
if !ok {
continue
}
if err := t.producer.SendPresence(ctx, content.UserID, presence, content.StatusMsg, content.LastActiveAgo); err != nil {
return err
}
}
return nil
}
// processReceiptEvent sends receipt events to JetStream
func (t *txnReq) processReceiptEvent(ctx context.Context,
userID, roomID, receiptType string,
timestamp gomatrixserverlib.Timestamp,
eventIDs []string,
) error {
if _, serverName, err := gomatrixserverlib.SplitID('@', userID); err != nil {
return nil
} else if serverName == t.ourServerName {
return nil
} else if serverName != t.Origin {
return nil
}
// store every event
for _, eventID := range eventIDs {
if err := t.producer.SendReceipt(ctx, userID, roomID, eventID, receiptType, timestamp); err != nil {
return fmt.Errorf("unable to set receipt event: %w", err)
}
}
return nil
}

View file

@ -1,552 +1,87 @@
package routing
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package routing_test
import (
"context"
"encoding/hex"
"encoding/json"
"fmt"
"net/http/httptest"
"testing"
"time"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/cmd/dendrite-demo-yggdrasil/signing"
fedAPI "github.com/matrix-org/dendrite/federationapi"
fedInternal "github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/ed25519"
)
const (
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
testDestination = gomatrixserverlib.ServerName("white.orchard")
testOrigin = gomatrixserverlib.ServerName("kaer.morhen")
)
var (
testRoomVersion = gomatrixserverlib.RoomVersionV1
testData = []json.RawMessage{
[]byte(`{"auth_events":[],"content":{"creator":"@userid:kaer.morhen"},"depth":0,"event_id":"$0ok8ynDp7kjc95e3:kaer.morhen","hashes":{"sha256":"17kPoH+h0Dk4Omn7Sus0qMb6+oGcf+CZFEgDhv7UKWs"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"jP4a04f5/F10Pw95FPpdCyKAO44JOwUQ/MZOOeA/RTU1Dn+AHPMzGSaZnuGjRr/xQuADt+I3ctb5ZQfLKNzHDw"}},"state_key":"","type":"m.room.create"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"content":{"membership":"join"},"depth":1,"event_id":"$LEwEu0kxrtu5fOiS:kaer.morhen","hashes":{"sha256":"B7M88PhXf3vd1LaFtjQutFu4x/w7fHD28XKZ4sAsJTo"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"p2vqmuJn7ZBRImctSaKbXCAxCcBlIjPH9JHte1ouIUGy84gpu4eLipOvSBCLL26hXfC0Zrm4WUto6Hr+ohdrCg"}},"state_key":"@userid:kaer.morhen","type":"m.room.member"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"join_rule":"public"},"depth":2,"event_id":"$SMHlqUrNhhBBRLeN:kaer.morhen","hashes":{"sha256":"vIuJQvmMjrGxshAkj1SXe0C4RqvMbv4ZADDw9pFCWqQ"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"hBMsb3Qppo3RaqqAl4JyTgaiWEbW5hlckATky6PrHun+F3YM203TzG7w9clwuQU5F5pZoB1a6nw+to0hN90FAw"}},"state_key":"","type":"m.room.join_rules"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"history_visibility":"shared"},"depth":3,"event_id":"$6F1yGIbO0J7TM93h:kaer.morhen","hashes":{"sha256":"Mr23GKSlZW7UCCYLgOWawI2Sg6KIoMjUWO2TDenuOgw"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$SMHlqUrNhhBBRLeN:kaer.morhen",{"sha256":"SylzE8U02I+6eyEHgL+FlU0L5YdqrVp8OOlxKS9VQW0"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sHLKrFI3hKGrEJfpMVZSDS3LvLasQsy50CTsOwru9XTVxgRsPo6wozNtRVjxo1J3Rk18RC9JppovmQ5VR5EcDw"}},"state_key":"","type":"m.room.history_visibility"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"ban":50,"events":null,"events_default":0,"invite":0,"kick":50,"redact":50,"state_default":50,"users":null,"users_default":0},"depth":4,"event_id":"$UKNe10XzYzG0TeA9:kaer.morhen","hashes":{"sha256":"ngbP3yja9U5dlckKerUs/fSOhtKxZMCVvsfhPURSS28"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$6F1yGIbO0J7TM93h:kaer.morhen",{"sha256":"A4CucrKSoWX4IaJXhq02mBg1sxIyZEftbC+5p3fZAvk"}]],"prev_state":[],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"zOmwlP01QL3yFchzuR9WHvogOoBZA3oVtNIF3lM0ZfDnqlSYZB9sns27G/4HVq0k7alaK7ZE3oGoCrVnMkPNCw"}},"state_key":"","type":"m.room.power_levels"}`),
// messages
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":5,"event_id":"$gl2T9l3qm0kUbiIJ:kaer.morhen","hashes":{"sha256":"Qx3nRMHLDPSL5hBAzuX84FiSSP0K0Kju2iFoBWH4Za8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$UKNe10XzYzG0TeA9:kaer.morhen",{"sha256":"KtSRyMjt0ZSjsv2koixTRCxIRCGoOp6QrKscsW97XRo"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"sqDgv3EG7ml5VREzmT9aZeBpS4gAPNIaIeJOwqjDhY0GPU/BcpX5wY4R7hYLrNe5cChgV+eFy/GWm1Zfg5FfDg"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":6,"event_id":"$MYSbs8m4rEbsCWXD:kaer.morhen","hashes":{"sha256":"kgbYM7v4Ud2YaBsjBTolM4ySg6rHcJNYI6nWhMSdFUA"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$gl2T9l3qm0kUbiIJ:kaer.morhen",{"sha256":"C/rD04h9wGxRdN2G/IBfrgoE1UovzLZ+uskwaKZ37/Q"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"x0UoKh968jj/F5l1/R7Ew0T6CTKuew3PLNHASNxqck/bkNe8yYQiDHXRr+kZxObeqPZZTpaF1+EI+bLU9W8GDQ"}},"type":"m.room.message"}`),
[]byte(`{"auth_events":[["$0ok8ynDp7kjc95e3:kaer.morhen",{"sha256":"sWCi6Ckp9rDimQON+MrUlNRkyfZ2tjbPbWfg2NMB18Q"}],["$LEwEu0kxrtu5fOiS:kaer.morhen",{"sha256":"1aKajq6DWHru1R1HJjvdWMEavkJJHGaTmPvfuERUXaA"}]],"content":{"body":"Test Message"},"depth":7,"event_id":"$N5x9WJkl9ClPrAEg:kaer.morhen","hashes":{"sha256":"FWM8oz4yquTunRZ67qlW2gzPDzdWfBP6RPHXhK1I/x8"},"origin":"kaer.morhen","origin_server_ts":0,"prev_events":[["$MYSbs8m4rEbsCWXD:kaer.morhen",{"sha256":"fatqgW+SE8mb2wFn3UN+drmluoD4UJ/EcSrL6Ur9q1M"}]],"room_id":"!roomid:kaer.morhen","sender":"@userid:kaer.morhen","signatures":{"kaer.morhen":{"ed25519:auto":"Y+LX/xcyufoXMOIoqQBNOzy6lZfUGB1ffgXIrSugk6obMiyAsiRejHQN/pciZXsHKxMJLYRFAz4zSJoS/LGPAA"}},"type":"m.room.message"}`),
}
testEvents = []*gomatrixserverlib.HeaderedEvent{}
testStateEvents = make(map[gomatrixserverlib.StateKeyTuple]*gomatrixserverlib.HeaderedEvent)
)
type sendContent struct {
PDUs []json.RawMessage `json:"pdus"`
EDUs []gomatrixserverlib.EDU `json:"edus"`
}
func init() {
for _, j := range testData {
e, err := gomatrixserverlib.NewEventFromTrustedJSON(j, false, testRoomVersion)
func TestHandleSend(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
base, close := testrig.CreateBaseDendrite(t, dbType)
defer close()
fedMux := mux.NewRouter().SkipClean(true).PathPrefix(httputil.PublicFederationPathPrefix).Subrouter().UseEncodedPath()
base.PublicFederationAPIMux = fedMux
base.Cfg.FederationAPI.Matrix.SigningIdentity.ServerName = testOrigin
base.Cfg.FederationAPI.Matrix.Metrics.Enabled = false
fedapi := fedAPI.NewInternalAPI(base, nil, nil, nil, nil, true)
serverKeyAPI := &signing.YggdrasilKeys{}
keyRing := serverKeyAPI.KeyRing()
r, ok := fedapi.(*fedInternal.FederationInternalAPI)
if !ok {
panic("This is a programming error.")
}
routing.Setup(base, nil, r, keyRing, nil, nil, nil, &base.Cfg.MSCs, nil, nil)
handler := fedMux.Get(routing.SendRouteName).GetHandler().ServeHTTP
_, sk, _ := ed25519.GenerateKey(nil)
keyID := signing.KeyID
pk := sk.Public().(ed25519.PublicKey)
serverName := gomatrixserverlib.ServerName(hex.EncodeToString(pk))
req := gomatrixserverlib.NewFederationRequest("PUT", serverName, testOrigin, "/send/1234")
content := sendContent{}
err := req.SetContent(content)
if err != nil {
panic("cannot load test data: " + err.Error())
t.Fatalf("Error: %s", err.Error())
}
h := e.Headered(testRoomVersion)
testEvents = append(testEvents, h)
if e.StateKey() != nil {
testStateEvents[gomatrixserverlib.StateKeyTuple{
EventType: e.Type(),
StateKey: *e.StateKey(),
}] = h
req.Sign(serverName, gomatrixserverlib.KeyID(keyID), sk)
httpReq, err := req.HTTPRequest()
if err != nil {
t.Fatalf("Error: %s", err.Error())
}
}
vars := map[string]string{"txnID": "1234"}
w := httptest.NewRecorder()
httpReq = mux.SetURLVars(httpReq, vars)
handler(w, httpReq)
res := w.Result()
assert.Equal(t, 200, res.StatusCode)
})
}
type testRoomserverAPI struct {
api.RoomserverInternalAPITrace
inputRoomEvents []api.InputRoomEvent
queryStateAfterEvents func(*api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse
queryEventsByID func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse
queryLatestEventsAndState func(*api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse
}
func (t *testRoomserverAPI) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) error {
t.inputRoomEvents = append(t.inputRoomEvents, request.InputRoomEvents...)
for _, ire := range request.InputRoomEvents {
fmt.Println("InputRoomEvents: ", ire.Event.EventID())
}
return nil
}
// Query the latest events and state for a room from the room server.
func (t *testRoomserverAPI) QueryLatestEventsAndState(
ctx context.Context,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
r := t.queryLatestEventsAndState(request)
response.RoomExists = r.RoomExists
response.RoomVersion = testRoomVersion
response.LatestEvents = r.LatestEvents
response.StateEvents = r.StateEvents
response.Depth = r.Depth
return nil
}
// Query the state after a list of events in a room from the room server.
func (t *testRoomserverAPI) QueryStateAfterEvents(
ctx context.Context,
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
response.RoomVersion = testRoomVersion
res := t.queryStateAfterEvents(request)
response.PrevEventsExist = res.PrevEventsExist
response.RoomExists = res.RoomExists
response.StateEvents = res.StateEvents
return nil
}
// Query a list of events by event ID.
func (t *testRoomserverAPI) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
res := t.queryEventsByID(request)
response.Events = res.Events
return nil
}
// Query if a server is joined to a room
func (t *testRoomserverAPI) QueryServerJoinedToRoom(
ctx context.Context,
request *api.QueryServerJoinedToRoomRequest,
response *api.QueryServerJoinedToRoomResponse,
) error {
response.RoomExists = true
response.IsInRoom = true
return nil
}
// Asks for the room version for a given room.
func (t *testRoomserverAPI) QueryRoomVersionForRoom(
ctx context.Context,
request *api.QueryRoomVersionForRoomRequest,
response *api.QueryRoomVersionForRoomResponse,
) error {
response.RoomVersion = testRoomVersion
return nil
}
func (t *testRoomserverAPI) QueryServerBannedFromRoom(
ctx context.Context, req *api.QueryServerBannedFromRoomRequest, res *api.QueryServerBannedFromRoomResponse,
) error {
res.Banned = false
return nil
}
type txnFedClient struct {
state map[string]gomatrixserverlib.RespState // event_id to response
stateIDs map[string]gomatrixserverlib.RespStateIDs // event_id to response
getEvent map[string]gomatrixserverlib.Transaction // event_id to response
getMissingEvents func(gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error)
}
func (c *txnFedClient) LookupState(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string, roomVersion gomatrixserverlib.RoomVersion) (
res gomatrixserverlib.RespState, err error,
) {
fmt.Println("testFederationClient.LookupState", eventID)
r, ok := c.state[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /state for event %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) LookupStateIDs(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, eventID string) (res gomatrixserverlib.RespStateIDs, err error) {
fmt.Println("testFederationClient.LookupStateIDs", eventID)
r, ok := c.stateIDs[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /state_ids for event %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) GetEvent(ctx context.Context, origin, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error) {
fmt.Println("testFederationClient.GetEvent", eventID)
r, ok := c.getEvent[eventID]
if !ok {
err = fmt.Errorf("txnFedClient: no /event for event ID %s", eventID)
return
}
res = r
return
}
func (c *txnFedClient) LookupMissingEvents(ctx context.Context, origin, s gomatrixserverlib.ServerName, roomID string, missing gomatrixserverlib.MissingEvents,
roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.RespMissingEvents, err error) {
return c.getMissingEvents(missing)
}
func mustCreateTransaction(rsAPI api.FederationRoomserverAPI, fedClient txnFederationClient, pdus []json.RawMessage) *txnReq {
t := &txnReq{
rsAPI: rsAPI,
keys: &test.NopJSONVerifier{},
federation: fedClient,
roomsMu: internal.NewMutexByRoom(),
}
t.PDUs = pdus
t.Origin = testOrigin
t.TransactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d", time.Now().UnixNano()))
t.Destination = testDestination
return t
}
func mustProcessTransaction(t *testing.T, txn *txnReq, pdusWithErrors []string) {
res, err := txn.processTransaction(context.Background())
if err != nil {
t.Errorf("txn.processTransaction returned an error: %v", err)
return
}
if len(res.PDUs) != len(txn.PDUs) {
t.Errorf("txn.processTransaction did not return results for all PDUs, got %d want %d", len(res.PDUs), len(txn.PDUs))
return
}
NextPDU:
for eventID, result := range res.PDUs {
if result.Error == "" {
continue
}
for _, eventIDWantError := range pdusWithErrors {
if eventID == eventIDWantError {
break NextPDU
}
}
t.Errorf("txn.processTransaction PDU %s returned an error %s", eventID, result.Error)
}
}
/*
func fromStateTuples(tuples []gomatrixserverlib.StateKeyTuple, omitTuples []gomatrixserverlib.StateKeyTuple) (result []*gomatrixserverlib.HeaderedEvent) {
NextTuple:
for _, t := range tuples {
for _, o := range omitTuples {
if t == o {
break NextTuple
}
}
h, ok := testStateEvents[t]
if ok {
result = append(result, h)
}
}
return
}
*/
func assertInputRoomEvents(t *testing.T, got []api.InputRoomEvent, want []*gomatrixserverlib.HeaderedEvent) {
for _, g := range got {
fmt.Println("GOT ", g.Event.EventID())
}
if len(got) != len(want) {
t.Errorf("wrong number of InputRoomEvents: got %d want %d", len(got), len(want))
return
}
for i := range got {
if got[i].Event.EventID() != want[i].EventID() {
t.Errorf("InputRoomEvents[%d] got %s want %s", i, got[i].Event.EventID(), want[i].EventID())
}
}
}
// The purpose of this test is to check that receiving an event over federation for which we have the prev_events works correctly, and passes it on
// to the roomserver. It's the most basic test possible.
func TestBasicTransaction(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}
// The purpose of this test is to check that if the event received fails auth checks the event is still sent to the roomserver
// as it does the auth check.
func TestTransactionFailAuthChecks(t *testing.T) {
rsAPI := &testRoomserverAPI{}
pdus := []json.RawMessage{
testData[len(testData)-1], // a message event
}
txn := mustCreateTransaction(rsAPI, &txnFedClient{}, pdus)
mustProcessTransaction(t, txn, []string{})
// expect message to be sent to the roomserver
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{testEvents[len(testEvents)-1]})
}
// The purpose of this test is to make sure that when an event is received for which we do not know the prev_events,
// we request them from /get_missing_events. It works by setting PrevEventsExist=false in the roomserver query response,
// resulting in a call to /get_missing_events which returns the missing prev event. Both events should be processed in
// topological order and sent to the roomserver.
/*
func TestTransactionFetchMissingPrevEvents(t *testing.T) {
haveEvent := testEvents[len(testEvents)-3]
prevEvent := testEvents[len(testEvents)-2]
inputEvent := testEvents[len(testEvents)-1]
var rsAPI *testRoomserverAPI // ref here so we can refer to inputRoomEvents inside these functions
rsAPI = &testRoomserverAPI{
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
res := api.QueryEventsByIDResponse{}
for _, ev := range testEvents {
for _, id := range req.EventIDs {
if ev.EventID() == id {
res.Events = append(res.Events, ev)
}
}
}
return res
},
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
return api.QueryStateAfterEventsResponse{
PrevEventsExist: true,
StateEvents: testEvents[:5],
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
missingPrevEvent := []string{"missing_prev_event"}
if len(req.PrevEventIDs) == 1 {
switch req.PrevEventIDs[0] {
case haveEvent.EventID():
missingPrevEvent = []string{}
case prevEvent.EventID():
// we only have this event if we've been send prevEvent
if len(rsAPI.inputRoomEvents) == 1 && rsAPI.inputRoomEvents[0].Event.EventID() == prevEvent.EventID() {
missingPrevEvent = []string{}
}
}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvent,
}
},
queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse {
return api.QueryLatestEventsAndStateResponse{
RoomExists: true,
Depth: haveEvent.Depth(),
LatestEvents: []gomatrixserverlib.EventReference{
haveEvent.EventReference(),
},
StateEvents: fromStateTuples(req.StateToFetch, nil),
}
},
}
cli := &txnFedClient{
getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) {
if !reflect.DeepEqual(missing.EarliestEvents, []string{haveEvent.EventID()}) {
t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, haveEvent.EventID())
}
if !reflect.DeepEqual(missing.LatestEvents, []string{inputEvent.EventID()}) {
t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, inputEvent.EventID())
}
return gomatrixserverlib.RespMissingEvents{
Events: []*gomatrixserverlib.Event{
prevEvent.Unwrap(),
},
}, nil
},
}
pdus := []json.RawMessage{
inputEvent.JSON(),
}
txn := mustCreateTransaction(rsAPI, cli, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{prevEvent, inputEvent})
}
// The purpose of this test is to check that when there are missing prev_events and we still haven't been able to fill
// in the hole with /get_missing_events that the state BEFORE the events we want to persist is fetched via /state_ids
// and /event. It works by setting PrevEventsExist=false in the roomserver query response, resulting in
// a call to /get_missing_events which returns 1 out of the 2 events it needs to fill in the gap. Synapse and Dendrite
// both give up after 1x /get_missing_events call, relying on requesting the state AFTER the missing event in order to
// continue. The DAG looks something like:
// FE GME TXN
// A ---> B ---> C ---> D
// TXN=event in the txn, GME=response to /get_missing_events, FE=roomserver's forward extremity. Should result in:
// - /state_ids?event=B is requested, then /event/B to get the state AFTER B. B is a state event.
// - state resolution is done to check C is allowed.
// This results in B being sent as an outlier FIRST, then C,D.
func TestTransactionFetchMissingStateByStateIDs(t *testing.T) {
eventA := testEvents[len(testEvents)-5]
// this is also len(testEvents)-4
eventB := testStateEvents[gomatrixserverlib.StateKeyTuple{
EventType: gomatrixserverlib.MRoomPowerLevels,
StateKey: "",
}]
eventC := testEvents[len(testEvents)-3]
eventD := testEvents[len(testEvents)-2]
fmt.Println("a:", eventA.EventID())
fmt.Println("b:", eventB.EventID())
fmt.Println("c:", eventC.EventID())
fmt.Println("d:", eventD.EventID())
var rsAPI *testRoomserverAPI
rsAPI = &testRoomserverAPI{
queryStateAfterEvents: func(req *api.QueryStateAfterEventsRequest) api.QueryStateAfterEventsResponse {
omitTuples := []gomatrixserverlib.StateKeyTuple{
{
EventType: gomatrixserverlib.MRoomPowerLevels,
StateKey: "",
},
}
askingForEvent := req.PrevEventIDs[0]
haveEventB := false
haveEventC := false
for _, ev := range rsAPI.inputRoomEvents {
switch ev.Event.EventID() {
case eventB.EventID():
haveEventB = true
omitTuples = nil // include event B now
case eventC.EventID():
haveEventC = true
}
}
prevEventExists := false
if askingForEvent == eventC.EventID() {
prevEventExists = haveEventC
} else if askingForEvent == eventB.EventID() {
prevEventExists = haveEventB
}
var stateEvents []*gomatrixserverlib.HeaderedEvent
if prevEventExists {
stateEvents = fromStateTuples(req.StateToFetch, omitTuples)
}
return api.QueryStateAfterEventsResponse{
PrevEventsExist: prevEventExists,
RoomExists: true,
StateEvents: stateEvents,
}
},
queryMissingAuthPrevEvents: func(req *api.QueryMissingAuthPrevEventsRequest) api.QueryMissingAuthPrevEventsResponse {
askingForEvent := req.PrevEventIDs[0]
haveEventB := false
haveEventC := false
for _, ev := range rsAPI.inputRoomEvents {
switch ev.Event.EventID() {
case eventB.EventID():
haveEventB = true
case eventC.EventID():
haveEventC = true
}
}
prevEventExists := false
if askingForEvent == eventC.EventID() {
prevEventExists = haveEventC
} else if askingForEvent == eventB.EventID() {
prevEventExists = haveEventB
}
var missingPrevEvent []string
if !prevEventExists {
missingPrevEvent = []string{"test"}
}
return api.QueryMissingAuthPrevEventsResponse{
RoomExists: true,
MissingAuthEventIDs: []string{},
MissingPrevEventIDs: missingPrevEvent,
}
},
queryLatestEventsAndState: func(req *api.QueryLatestEventsAndStateRequest) api.QueryLatestEventsAndStateResponse {
omitTuples := []gomatrixserverlib.StateKeyTuple{
{EventType: gomatrixserverlib.MRoomPowerLevels, StateKey: ""},
}
return api.QueryLatestEventsAndStateResponse{
RoomExists: true,
Depth: eventA.Depth(),
LatestEvents: []gomatrixserverlib.EventReference{
eventA.EventReference(),
},
StateEvents: fromStateTuples(req.StateToFetch, omitTuples),
}
},
queryEventsByID: func(req *api.QueryEventsByIDRequest) api.QueryEventsByIDResponse {
var res api.QueryEventsByIDResponse
fmt.Println("queryEventsByID ", req.EventIDs)
for _, wantEventID := range req.EventIDs {
for _, ev := range testStateEvents {
// roomserver is missing the power levels event unless it's been sent to us recently as an outlier
if wantEventID == eventB.EventID() {
fmt.Println("Asked for pl event")
for _, inEv := range rsAPI.inputRoomEvents {
fmt.Println("recv ", inEv.Event.EventID())
if inEv.Event.EventID() == wantEventID {
res.Events = append(res.Events, inEv.Event)
break
}
}
continue
}
if ev.EventID() == wantEventID {
res.Events = append(res.Events, ev)
}
}
}
return res
},
}
// /state_ids for event B returns every state event but B (it's the state before)
var authEventIDs []string
var stateEventIDs []string
for _, ev := range testStateEvents {
if ev.EventID() == eventB.EventID() {
continue
}
// state res checks what auth events you give it, and this isn't a valid auth event
if ev.Type() != gomatrixserverlib.MRoomHistoryVisibility {
authEventIDs = append(authEventIDs, ev.EventID())
}
stateEventIDs = append(stateEventIDs, ev.EventID())
}
cli := &txnFedClient{
stateIDs: map[string]gomatrixserverlib.RespStateIDs{
eventB.EventID(): {
StateEventIDs: stateEventIDs,
AuthEventIDs: authEventIDs,
},
},
// /event for event B returns it
getEvent: map[string]gomatrixserverlib.Transaction{
eventB.EventID(): {
PDUs: []json.RawMessage{
eventB.JSON(),
},
},
},
// /get_missing_events should be done exactly once
getMissingEvents: func(missing gomatrixserverlib.MissingEvents) (res gomatrixserverlib.RespMissingEvents, err error) {
if !reflect.DeepEqual(missing.EarliestEvents, []string{eventA.EventID()}) {
t.Errorf("call to /get_missing_events wrong earliest events: got %v want %v", missing.EarliestEvents, eventA.EventID())
}
if !reflect.DeepEqual(missing.LatestEvents, []string{eventD.EventID()}) {
t.Errorf("call to /get_missing_events wrong latest events: got %v want %v", missing.LatestEvents, eventD.EventID())
}
// just return event C, not event B so /state_ids logic kicks in as there will STILL be missing prev_events
return gomatrixserverlib.RespMissingEvents{
Events: []*gomatrixserverlib.Event{
eventC.Unwrap(),
},
}, nil
},
}
pdus := []json.RawMessage{
eventD.JSON(),
}
txn := mustCreateTransaction(rsAPI, cli, pdus)
mustProcessTransaction(t, txn, nil)
assertInputRoomEvents(t, rsAPI.inputRoomEvents, []*gomatrixserverlib.HeaderedEvent{eventB, eventC, eventD})
}
*/

View file

@ -1,6 +1,7 @@
package statistics
import (
"context"
"math"
"math/rand"
"sync"
@ -28,14 +29,24 @@ type Statistics struct {
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
// How many times should we tolerate consecutive failures before we
// mark the destination as offline. At this point we should attempt
// to send messages to the user's async relay servers if we know them.
FailuresUntilAssumedOffline uint32
}
func NewStatistics(db storage.Database, failuresUntilBlacklist uint32) Statistics {
func NewStatistics(
db storage.Database,
failuresUntilBlacklist uint32,
failuresUntilAssumedOffline uint32,
) Statistics {
return Statistics{
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics),
DB: db,
FailuresUntilBlacklist: failuresUntilBlacklist,
FailuresUntilAssumedOffline: failuresUntilAssumedOffline,
backoffTimers: make(map[gomatrixserverlib.ServerName]*time.Timer),
servers: make(map[gomatrixserverlib.ServerName]*ServerStatistics),
}
}
@ -50,8 +61,9 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
if !found {
s.mutex.Lock()
server = &ServerStatistics{
statistics: s,
serverName: serverName,
statistics: s,
serverName: serverName,
knownRelayServers: []gomatrixserverlib.ServerName{},
}
s.servers[serverName] = server
s.mutex.Unlock()
@ -61,24 +73,49 @@ func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerS
} else {
server.blacklisted.Store(blacklisted)
}
assumedOffline, err := s.DB.IsServerAssumedOffline(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get assumed offline entry %q", serverName)
} else {
server.assumedOffline.Store(assumedOffline)
}
knownRelayServers, err := s.DB.P2PGetRelayServersForServer(context.Background(), serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get relay server list for %q", serverName)
} else {
server.relayMutex.Lock()
server.knownRelayServers = knownRelayServers
server.relayMutex.Unlock()
}
}
return server
}
type SendMethod uint8
const (
SendDirect SendMethod = iota
SendViaRelay
)
// ServerStatistics contains information about our interactions with a
// remote federated host, e.g. how many times we were successful, how
// many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative.
type ServerStatistics struct {
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
assumedOffline atomic.Bool // is the node assumed to be offline
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
successCounter atomic.Uint32 // how many times have we succeeded?
backoffNotifier func() // notifies destination queue when backoff completes
notifierMutex sync.Mutex
knownRelayServers []gomatrixserverlib.ServerName
relayMutex sync.Mutex
}
const maxJitterMultiplier = 1.4
@ -113,13 +150,19 @@ func (s *ServerStatistics) AssignBackoffNotifier(notifier func()) {
// attempt, which increases the sent counter and resets the idle and
// failure counters. If a host was blacklisted at this point then
// we will unblacklist it.
func (s *ServerStatistics) Success() {
// `relay` specifies whether the success was to the actual destination
// or one of their relay servers.
func (s *ServerStatistics) Success(method SendMethod) {
s.cancel()
s.backoffCount.Store(0)
s.successCounter.Inc()
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
// NOTE : Sending to the final destination vs. a relay server has
// slightly different semantics.
if method == SendDirect {
s.successCounter.Inc()
if s.blacklisted.Load() && s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
}
}
}
}
@ -139,7 +182,18 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
// start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CompareAndSwap(false, true) {
if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
backoffCount := s.backoffCount.Inc()
if backoffCount >= s.statistics.FailuresUntilAssumedOffline {
s.assumedOffline.CompareAndSwap(false, true)
if s.statistics.DB != nil {
if err := s.statistics.DB.SetServerAssumedOffline(context.Background(), s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to set %q as assumed offline", s.serverName)
}
}
}
if backoffCount >= s.statistics.FailuresUntilBlacklist {
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
@ -157,13 +211,21 @@ func (s *ServerStatistics) Failure() (time.Time, bool) {
s.backoffUntil.Store(until)
s.statistics.backoffMutex.Lock()
defer s.statistics.backoffMutex.Unlock()
s.statistics.backoffTimers[s.serverName] = time.AfterFunc(time.Until(until), s.backoffFinished)
s.statistics.backoffMutex.Unlock()
}
return s.backoffUntil.Load().(time.Time), false
}
// MarkServerAlive removes the assumed offline and blacklisted statuses from this server.
// Returns whether the server was blacklisted before this point.
func (s *ServerStatistics) MarkServerAlive() bool {
s.removeAssumedOffline()
wasBlacklisted := s.removeBlacklist()
return wasBlacklisted
}
// ClearBackoff stops the backoff timer for this destination if it is running
// and removes the timer from the backoffTimers map.
func (s *ServerStatistics) ClearBackoff() {
@ -191,13 +253,13 @@ func (s *ServerStatistics) backoffFinished() {
}
// BackoffInfo returns information about the current or previous backoff.
// Returns the last backoffUntil time and whether the server is currently blacklisted or not.
func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
// Returns the last backoffUntil time.
func (s *ServerStatistics) BackoffInfo() *time.Time {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
return &until, s.blacklisted.Load()
return &until
}
return nil, s.blacklisted.Load()
return nil
}
// Blacklisted returns true if the server is blacklisted and false
@ -206,10 +268,33 @@ func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// RemoveBlacklist removes the blacklisted status from the server.
func (s *ServerStatistics) RemoveBlacklist() {
// AssumedOffline returns true if the server is assumed offline and false
// otherwise.
func (s *ServerStatistics) AssumedOffline() bool {
return s.assumedOffline.Load()
}
// removeBlacklist removes the blacklisted status from the server.
// Returns whether the server was blacklisted.
func (s *ServerStatistics) removeBlacklist() bool {
var wasBlacklisted bool
if s.Blacklisted() {
wasBlacklisted = true
_ = s.statistics.DB.RemoveServerFromBlacklist(s.serverName)
}
s.cancel()
s.backoffCount.Store(0)
return wasBlacklisted
}
// removeAssumedOffline removes the assumed offline status from the server.
func (s *ServerStatistics) removeAssumedOffline() {
if s.AssumedOffline() {
_ = s.statistics.DB.RemoveServerAssumedOffline(context.Background(), s.serverName)
}
s.assumedOffline.Store(false)
}
// SuccessCount returns the number of successful requests. This is
@ -217,3 +302,46 @@ func (s *ServerStatistics) RemoveBlacklist() {
func (s *ServerStatistics) SuccessCount() uint32 {
return s.successCounter.Load()
}
// KnownRelayServers returns the list of relay servers associated with this
// server.
func (s *ServerStatistics) KnownRelayServers() []gomatrixserverlib.ServerName {
s.relayMutex.Lock()
defer s.relayMutex.Unlock()
return s.knownRelayServers
}
func (s *ServerStatistics) AddRelayServers(relayServers []gomatrixserverlib.ServerName) {
seenSet := make(map[gomatrixserverlib.ServerName]bool)
uniqueList := []gomatrixserverlib.ServerName{}
for _, srv := range relayServers {
if seenSet[srv] {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
err := s.statistics.DB.P2PAddRelayServersForServer(context.Background(), s.serverName, uniqueList)
if err != nil {
logrus.WithError(err).Errorf("Failed to add relay servers for %q. Servers: %v", s.serverName, uniqueList)
return
}
for _, newServer := range uniqueList {
alreadyKnown := false
knownRelayServers := s.KnownRelayServers()
for _, srv := range knownRelayServers {
if srv == newServer {
alreadyKnown = true
}
}
if !alreadyKnown {
{
s.relayMutex.Lock()
s.knownRelayServers = append(s.knownRelayServers, newServer)
s.relayMutex.Unlock()
}
}
}
}

View file

@ -4,17 +4,26 @@ import (
"math"
"testing"
"time"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
FailuresUntilAssumedOffline = 3
FailuresUntilBlacklist = 8
)
func TestBackoff(t *testing.T) {
stats := NewStatistics(nil, 7)
stats := NewStatistics(nil, FailuresUntilBlacklist, FailuresUntilAssumedOffline)
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
}
// Start by checking that counting successes works.
server.Success()
server.Success(SendDirect)
if successes := server.SuccessCount(); successes != 1 {
t.Fatalf("Expected success count 1, got %d", successes)
}
@ -31,9 +40,8 @@ func TestBackoff(t *testing.T) {
// side effects since a backoff is already in progress. If it does
// then we'll fail.
until, blacklisted := server.Failure()
// Get the duration.
_, blacklist := server.BackoffInfo()
blacklist := server.Blacklisted()
assumedOffline := server.AssumedOffline()
duration := time.Until(until)
// Unset the backoff, or otherwise our next call will think that
@ -41,16 +49,43 @@ func TestBackoff(t *testing.T) {
server.cancel()
server.backoffStarted.Store(false)
if i >= stats.FailuresUntilAssumedOffline {
if !assumedOffline {
t.Fatalf("Backoff %d should have resulted in assuming the destination was offline but didn't", i)
}
}
// Check if we should be assumed offline by now.
if i >= stats.FailuresUntilAssumedOffline {
if !assumedOffline {
t.Fatalf("Backoff %d should have resulted in assumed offline but didn't", i)
} else {
t.Logf("Backoff %d is assumed offline as expected", i)
}
} else {
if assumedOffline {
t.Fatalf("Backoff %d should not have resulted in assumed offline but did", i)
} else {
t.Logf("Backoff %d is not assumed offline as expected", i)
}
}
// Check if we should be blacklisted by now.
if i >= stats.FailuresUntilBlacklist {
if !blacklist {
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
} else if blacklist != blacklisted {
t.Fatalf("BackoffInfo and Failure returned different blacklist values")
t.Fatalf("Blacklisted and Failure returned different blacklist values")
} else {
t.Logf("Backoff %d is blacklisted as expected", i)
continue
}
} else {
if blacklist {
t.Fatalf("Backoff %d should not have resulted in blacklist but did", i)
} else {
t.Logf("Backoff %d is not blacklisted as expected", i)
}
}
// Check if the duration is what we expect.
@ -69,3 +104,14 @@ func TestBackoff(t *testing.T) {
}
}
}
func TestRelayServersListing(t *testing.T) {
stats := NewStatistics(test.NewInMemoryFederationDatabase(), FailuresUntilBlacklist, FailuresUntilAssumedOffline)
server := ServerStatistics{statistics: &stats}
server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"})
relayServers := server.KnownRelayServers()
assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers)
server.AddRelayServers([]gomatrixserverlib.ServerName{"relay1", "relay1", "relay2"})
relayServers = server.KnownRelayServers()
assert.Equal(t, []gomatrixserverlib.ServerName{"relay1", "relay2"}, relayServers)
}

View file

@ -20,11 +20,12 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/types"
)
type Database interface {
P2PDatabase
gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
@ -34,16 +35,16 @@ type Database interface {
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
@ -54,6 +55,18 @@ type Database interface {
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
// Adds the server to the list of assumed offline servers.
// If the server already exists in the table, nothing happens and returns success.
SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
// Removes the server from the list of assumed offline servers.
// If the server doesn't exist in the table, nothing happens and returns success.
RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
// Purges all entries from the assumed offline table.
RemoveAllServersAssumedOffline(ctx context.Context) error
// Gets whether the provided server is present in the table.
// If it is present, returns true. If not, returns false.
IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error)
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
@ -74,3 +87,21 @@ type Database interface {
PurgeRoom(ctx context.Context, roomID string) error
}
type P2PDatabase interface {
// Stores the given list of servers as relay servers for the provided destination server.
// Providing duplicates will only lead to a single entry and won't lead to an error.
P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
// Get the list of relay servers associated with the provided destination server.
// If no entry exists in the table, an empty list is returned and does not result in an error.
P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
// Deletes any entries for the provided destination server that match the provided relayServers list.
// If any of the provided servers don't match an entry, nothing happens and no error is returned.
P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
// Deletes all entries for the provided destination server.
// If the destination server doesn't exist in the table, nothing happens and no error is returned.
P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error
}

View file

@ -0,0 +1,107 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const assumedOfflineSchema = `
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
-- The assumed offline server name
server_name TEXT PRIMARY KEY NOT NULL
);
`
const insertAssumedOfflineSQL = "" +
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectAssumedOfflineSQL = "" +
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAllAssumedOfflineSQL = "" +
"TRUNCATE federationsender_assumed_offline"
type assumedOfflineStatements struct {
db *sql.DB
insertAssumedOfflineStmt *sql.Stmt
selectAssumedOfflineStmt *sql.Stmt
deleteAssumedOfflineStmt *sql.Stmt
deleteAllAssumedOfflineStmt *sql.Stmt
}
func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
s = &assumedOfflineStatements{
db: db,
}
_, err = db.Exec(assumedOfflineSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
}.Prepare(db)
}
func (s *assumedOfflineStatements) InsertAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) SelectAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is assume offline, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *assumedOfflineStatements) DeleteAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -0,0 +1,137 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
deleteRelayServersStmt *sql.Stmt
deleteAllRelayServersStmt *sql.Stmt
}
func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db)
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
if err != nil {
return nil, err
}
relayServers, err := NewPostgresRelayServersTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,

View file

@ -0,0 +1,42 @@
// Copyright 2023 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
// We don't actually export the NIDs but we need the caller to be able
// to pass them back so that we can clean up if the transaction sends
// successfully.
package receipt
import "fmt"
// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry
// in some database table.
// The internal nid value cannot be modified after a Receipt has been created.
// This guarantees a receipt will always refer to the same table entry that it was created
// to represent.
type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) GetNID() int64 {
return r.nid
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}

View file

@ -20,6 +20,7 @@ import (
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal/caching"
@ -37,6 +38,8 @@ type Database struct {
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationAssumedOffline tables.FederationAssumedOffline
FederationRelayServers tables.FederationRelayServers
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
@ -44,22 +47,6 @@ type Database struct {
ServerSigningKeys tables.FederationServerSigningKeys
}
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
// We don't actually export the NIDs but we need the caller to be able
// to pass them back so that we can clean up if the transaction sends
// successfully.
type Receipt struct {
nid int64
}
func NewReceipt(nid int64) Receipt {
return Receipt{nid: nid}
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update, or nil if this was a duplicate message.
// This is called when we receive a message from kafka, so we pass in
@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts(
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
func (d *Database) GetJoinedHostsForRooms(
ctx context.Context,
roomIDs []string,
excludeSelf,
excludeBlacklisted bool,
) ([]gomatrixserverlib.ServerName, error) {
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
if err != nil {
return nil, err
@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string,
// metadata entries.
func (d *Database) StoreJSON(
ctx context.Context, js string,
) (*Receipt, error) {
) (*receipt.Receipt, error) {
var nid int64
var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -149,18 +143,21 @@ func (d *Database) StoreJSON(
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
newReceipt := receipt.NewReceipt(nid)
return &newReceipt, nil
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) AddServerToBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
func (d *Database) RemoveServerFromBlacklist(
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
})
@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error {
})
}
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
func (d *Database) IsServerBlacklisted(
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
}
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) SetServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName)
})
}
func (d *Database) RemoveAllServersAssumedOffline(
ctx context.Context,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn)
})
}
func (d *Database) IsServerAssumedOffline(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (bool, error) {
return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName)
}
func (d *Database) P2PAddRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) P2PGetRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName)
}
func (d *Database) P2PRemoveRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers)
})
}
func (d *Database) P2PRemoveAllRelayServersForServer(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName)
})
}
func (d *Database) AddOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
func (d *Database) GetOutboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID,
peekID string,
) (*types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
func (d *Database) GetOutboundPeeks(
ctx context.Context,
roomID string,
) ([]types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
}
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) AddInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
func (d *Database) RenewInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
renewalInterval int64,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
func (d *Database) GetInboundPeek(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
roomID string,
peekID string,
) (*types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
func (d *Database) GetInboundPeeks(
ctx context.Context,
roomID string,
) ([]types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
}
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
func (d *Database) UpdateNotaryKeys(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
serverKeys gomatrixserverlib.ServerKeys,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
validUntil := serverKeys.ValidUntilTS
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv
}
func (d *Database) GetNotaryKeys(
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
ctx context.Context,
serverName gomatrixserverlib.ServerName,
optKeyIDs []gomatrixserverlib.KeyID,
) (sks []gomatrixserverlib.ServerKeys, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)

View file

@ -22,6 +22,7 @@ import (
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
)
@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{
func (d *Database) AssociateEDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
dbReceipt *receipt.Receipt,
eduType string,
expireEDUTypes map[string]time.Duration,
) error {
@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations(
var err error
for destination := range destinations {
err = d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
ctx, // context
txn, // SQL transaction
eduType, // EDU type for coalescing
destination, // destination server name
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
expiresAt, // The timestamp this EDU will expire
)
}
return err
@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) (
edus map[*Receipt]*gomatrixserverlib.EDU,
edus map[*receipt.Receipt]*gomatrixserverlib.EDU,
err error,
) {
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
if err != nil {
@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs(
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok {
edus[&Receipt{nid}] = edu
newReceipt := receipt.NewReceipt(nid)
edus[&newReceipt] = edu
} else {
retrieve = append(retrieve, nid)
}
@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs(
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
edus[&Receipt{nid}] = &event
newReceipt := receipt.NewReceipt(nid)
edus[&newReceipt] = &event
d.Cache.StoreFederationQueuedEDU(nid, &event)
}
@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs(
func (d *Database) CleanEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
receipts []*receipt.Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
@ -132,7 +135,7 @@ func (d *Database) CleanEDUs(
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
nids[i] = receipts[i].GetNID()
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {

View file

@ -21,6 +21,7 @@ import (
"errors"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
"github.com/matrix-org/gomatrixserverlib"
)
@ -30,17 +31,17 @@ import (
func (d *Database) AssociatePDUWithDestinations(
ctx context.Context,
destinations map[gomatrixserverlib.ServerName]struct{},
receipt *Receipt,
dbReceipt *receipt.Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
for destination := range destinations {
err = d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
ctx, // context
txn, // SQL transaction
"", // transaction ID
destination, // destination server name
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
)
}
return err
@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs(
serverName gomatrixserverlib.ServerName,
limit int,
) (
events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent,
err error,
) {
// Strictly speaking this doesn't need to be using the writer
@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs(
// a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify
// the database.
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil {
@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs(
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok {
events[&Receipt{nid}] = event
newReceipt := receipt.NewReceipt(nid)
events[&newReceipt] = event
} else {
retrieve = append(retrieve, nid)
}
@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs(
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
events[&Receipt{nid}] = &event
newReceipt := receipt.NewReceipt(nid)
events[&newReceipt] = &event
d.Cache.StoreFederationQueuedPDU(nid, &event)
}
@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs(
func (d *Database) CleanPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
receipts []*receipt.Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
@ -111,7 +114,7 @@ func (d *Database) CleanPDUs(
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
nids[i] = receipts[i].GetNID()
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {

View file

@ -0,0 +1,107 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const assumedOfflineSchema = `
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
-- The assumed offline server name
server_name TEXT PRIMARY KEY NOT NULL
);
`
const insertAssumedOfflineSQL = "" +
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectAssumedOfflineSQL = "" +
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
const deleteAllAssumedOfflineSQL = "" +
"DELETE FROM federationsender_assumed_offline"
type assumedOfflineStatements struct {
db *sql.DB
insertAssumedOfflineStmt *sql.Stmt
selectAssumedOfflineStmt *sql.Stmt
deleteAssumedOfflineStmt *sql.Stmt
deleteAllAssumedOfflineStmt *sql.Stmt
}
func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
s = &assumedOfflineStatements{
db: db,
}
_, err = db.Exec(assumedOfflineSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
}.Prepare(db)
}
func (s *assumedOfflineStatements) InsertAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) SelectAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is assume offline, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *assumedOfflineStatements) DeleteAssumedOffline(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -0,0 +1,148 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const relayServersSchema = `
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
-- The destination server name
server_name TEXT NOT NULL,
-- The relay server name for a given destination
relay_server_name TEXT NOT NULL,
UNIQUE (server_name, relay_server_name)
);
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
ON federationsender_relay_servers (server_name);
`
const insertRelayServersSQL = "" +
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
" ON CONFLICT DO NOTHING"
const selectRelayServersSQL = "" +
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
const deleteRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
const deleteAllRelayServersSQL = "" +
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
type relayServersStatements struct {
db *sql.DB
insertRelayServersStmt *sql.Stmt
selectRelayServersStmt *sql.Stmt
// deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
deleteAllRelayServersStmt *sql.Stmt
}
func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
s = &relayServersStatements{
db: db,
}
_, err = db.Exec(relayServersSchema)
if err != nil {
return
}
return s, sqlutil.StatementList{
{&s.insertRelayServersStmt, insertRelayServersSQL},
{&s.selectRelayServersStmt, selectRelayServersSQL},
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
}.Prepare(db)
}
func (s *relayServersStatements) InsertRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
for _, relayServer := range relayServers {
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
return err
}
}
return nil
}
func (s *relayServersStatements) SelectRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
rows, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var relayServer string
if err = rows.Scan(&relayServer); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(relayServer))
}
return result, nil
}
func (s *relayServersStatements) DeleteRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
relayServers []gomatrixserverlib.ServerName,
) error {
deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
deleteStmt, err := s.db.Prepare(deleteSQL)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
params := make([]interface{}, len(relayServers)+1)
params[0] = serverName
for i, v := range relayServers {
params[i+1] = v
}
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *relayServersStatements) DeleteAllRelayServers(
ctx context.Context,
txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
return err
}
return nil
}

View file

@ -1,5 +1,4 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
if err != nil {
return nil, err
}
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
if err != nil {
return nil, err
}
relayServers, err := NewSQLiteRelayServersTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil {
return nil, err
@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationAssumedOffline: assumedOffline,
FederationRelayServers: relayServers,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,

View file

@ -6,14 +6,13 @@ import (
"testing"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/test/testrig"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
)
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) {
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
})
}
func TestServersAssumedOffline(t *testing.T) {
server1 := gomatrixserverlib.ServerName("server1")
server2 := gomatrixserverlib.ServerName("server2")
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
// Set server1 & server2 as assumed offline.
err := db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
err = db.SetServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
// Ensure both servers are assumed offline.
isOffline, err := db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.True(t, isOffline)
// Set server1 as not assumed offline.
err = db.RemoveServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
// Ensure both servers have correct state.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.True(t, isOffline)
// Re-set server1 as assumed offline.
err = db.SetServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
// Ensure server1 is assumed offline.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.True(t, isOffline)
err = db.RemoveAllServersAssumedOffline(context.Background())
assert.Nil(t, err)
// Ensure both servers have correct state.
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
assert.Nil(t, err)
assert.False(t, isOffline)
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
assert.Nil(t, err)
assert.False(t, isOffline)
})
}
func TestRelayServersStored(t *testing.T) {
server := gomatrixserverlib.ServerName("server")
relayServer1 := gomatrixserverlib.ServerName("relayserver1")
relayServer2 := gomatrixserverlib.ServerName("relayserver2")
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, closeDB := mustCreateFederationDatabase(t, dbType)
defer closeDB()
err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Equal(t, relayServer1, relayServers[0])
assert.Equal(t, relayServer2, relayServers[1])
err = db.P2PRemoveAllRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
assert.Nil(t, err)
assert.Zero(t, len(relayServers))
})
}

View file

@ -49,6 +49,19 @@ type FederationQueueJSON interface {
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationQueueTransactions interface {
InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
}
type FederationTransactionJSON interface {
InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
@ -66,6 +79,20 @@ type FederationBlacklist interface {
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
}
type FederationAssumedOffline interface {
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error
}
type FederationRelayServers interface {
InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
}
type FederationOutboundPeeks interface {
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)

View file

@ -0,0 +1,224 @@
package tables_test
import (
"context"
"database/sql"
"testing"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/gomatrixserverlib"
"github.com/stretchr/testify/assert"
)
const (
server1 = "server1"
server2 = "server2"
server3 = "server3"
server4 = "server4"
)
type RelayServersDatabase struct {
DB *sql.DB
Writer sqlutil.Writer
Table tables.FederationRelayServers
}
func mustCreateRelayServersTable(
t *testing.T,
dbType test.DBType,
) (database RelayServersDatabase, close func()) {
t.Helper()
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := sqlutil.Open(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, sqlutil.NewExclusiveWriter())
assert.NoError(t, err)
var tab tables.FederationRelayServers
switch dbType {
case test.DBTypePostgres:
tab, err = postgres.NewPostgresRelayServersTable(db)
assert.NoError(t, err)
case test.DBTypeSQLite:
tab, err = sqlite3.NewSQLiteRelayServersTable(db)
assert.NoError(t, err)
}
assert.NoError(t, err)
database = RelayServersDatabase{
DB: db,
Writer: sqlutil.NewDummyWriter(),
Table: tab,
}
return database, close
}
func Equal(a, b []gomatrixserverlib.ServerName) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if v != b[i] {
return false
}
}
return true
}
func TestShouldInsertRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldInsertRelayServersWithDuplicates(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2}
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
// Insert the same list again, this shouldn't fail and should have no effect.
err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldGetRelayServersUnknownDestination(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
// Query relay servers for a destination that doesn't exist in the table.
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, []gomatrixserverlib.ServerName{}) {
t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers)
}
})
}
func TestShouldDeleteCorrectRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
}
expectedRelayServers := []gomatrixserverlib.ServerName{server3}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}
func TestShouldDeleteAllRelayServers(t *testing.T) {
ctx := context.Background()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateRelayServersTable(t, dbType)
defer close()
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
if err != nil {
t.Fatalf("Failed inserting transaction: %s", err.Error())
}
err = db.Table.DeleteAllRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
}
expectedRelayServers1 := []gomatrixserverlib.ServerName{}
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers1) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
}
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
if err != nil {
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
}
if !Equal(relayServers, expectedRelayServers) {
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
}
})
}