Merge federationapi, federationsender, signingkeyserver components (#2055)

* Initial federation sender -> federation API refactoring

* Move base into own package, avoids import cycle

* Fix build errors

* Fix tests

* Add signing key server tables

* Try to fold signing key server into federation API

* Fix dendritejs builds

* Update embedded interfaces

* Fix panic, fix lint error

* Update configs, docker

* Rename some things

* Reuse same keyring on the implementing side

* Fix federation tests, `NewBaseDendrite` can accept freeform options

* Fix build

* Update create_db, configs

* Name tables back

* Don't rename federationsender consumer for now
This commit is contained in:
Neil Alexander 2021-11-24 10:45:23 +00:00 committed by GitHub
parent 6e93531e94
commit ec716793eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
136 changed files with 1211 additions and 1786 deletions

208
federationapi/api/api.go Normal file
View file

@ -0,0 +1,208 @@
package api
import (
"context"
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
)
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
// this interface are of type FederationClientError
type FederationClient interface {
gomatrixserverlib.BackfillClient
gomatrixserverlib.FederatedStateClient
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
}
// FederationClientError is returned from FederationClient methods in the event of a problem.
type FederationClientError struct {
Err string
RetryAfter time.Duration
Blacklisted bool
}
func (e *FederationClientError) Error() string {
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
}
// FederationInternalAPI is used to query information from the federation sender.
type FederationInternalAPI interface {
FederationClient
gomatrixserverlib.KeyDatabase
KeyRing() *gomatrixserverlib.KeyRing
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
// PerformDirectoryLookup looks up a remote room ID from a room alias.
PerformDirectoryLookup(
ctx context.Context,
request *PerformDirectoryLookupRequest,
response *PerformDirectoryLookupResponse,
) error
// Query the server names of the joined hosts in a room.
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice
// containing only the server names (without information for membership events).
// The response will include this server if they are joined to the room.
QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *QueryJoinedHostServerNamesInRoomRequest,
response *QueryJoinedHostServerNamesInRoomResponse,
) error
// Handle an instruction to make_join & send_join with a remote server.
PerformJoin(
ctx context.Context,
request *PerformJoinRequest,
response *PerformJoinResponse,
)
// Handle an instruction to peek a room on a remote server.
PerformOutboundPeek(
ctx context.Context,
request *PerformOutboundPeekRequest,
response *PerformOutboundPeekResponse,
) error
// Handle an instruction to make_leave & send_leave with a remote server.
PerformLeave(
ctx context.Context,
request *PerformLeaveRequest,
response *PerformLeaveResponse,
) error
// Handle sending an invite to a remote server.
PerformInvite(
ctx context.Context,
request *PerformInviteRequest,
response *PerformInviteResponse,
) error
// Notifies the federation sender that these servers may be online and to retry sending messages.
PerformServersAlive(
ctx context.Context,
request *PerformServersAliveRequest,
response *PerformServersAliveResponse,
) error
// Broadcasts an EDU to all servers in rooms we are joined to.
PerformBroadcastEDU(
ctx context.Context,
request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse,
) error
}
type QueryServerKeysRequest struct {
ServerName gomatrixserverlib.ServerName
KeyIDToCriteria map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria
}
func (q *QueryServerKeysRequest) KeyIDs() []gomatrixserverlib.KeyID {
kids := make([]gomatrixserverlib.KeyID, len(q.KeyIDToCriteria))
i := 0
for keyID := range q.KeyIDToCriteria {
kids[i] = keyID
i++
}
return kids
}
type QueryServerKeysResponse struct {
ServerKeys []gomatrixserverlib.ServerKeys
}
type QueryPublicKeysRequest struct {
Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"`
}
type QueryPublicKeysResponse struct {
Results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"results"`
}
type PerformDirectoryLookupRequest struct {
RoomAlias string `json:"room_alias"`
ServerName gomatrixserverlib.ServerName `json:"server_name"`
}
type PerformDirectoryLookupResponse struct {
RoomID string `json:"room_id"`
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformJoinRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
ServerNames types.ServerNames `json:"server_names"`
Content map[string]interface{} `json:"content"`
}
type PerformJoinResponse struct {
JoinedVia gomatrixserverlib.ServerName
LastError *gomatrix.HTTPError
}
type PerformOutboundPeekRequest struct {
RoomID string `json:"room_id"`
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
ServerNames types.ServerNames `json:"server_names"`
}
type PerformOutboundPeekResponse struct {
LastError *gomatrix.HTTPError
}
type PerformLeaveRequest struct {
RoomID string `json:"room_id"`
UserID string `json:"user_id"`
ServerNames types.ServerNames `json:"server_names"`
}
type PerformLeaveResponse struct {
}
type PerformInviteRequest struct {
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"`
}
type PerformInviteResponse struct {
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
}
type PerformServersAliveRequest struct {
Servers []gomatrixserverlib.ServerName
}
type PerformServersAliveResponse struct {
}
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomRequest struct {
RoomID string `json:"room_id"`
}
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
type QueryJoinedHostServerNamesInRoomResponse struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
}
type PerformBroadcastEDURequest struct {
}
type PerformBroadcastEDUResponse struct {
}
type InputPublicKeysRequest struct {
Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"`
}
type InputPublicKeysResponse struct {
}

View file

@ -0,0 +1,249 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"fmt"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
log "github.com/sirupsen/logrus"
)
// OutputEDUConsumer consumes events that originate in EDU server.
type OutputEDUConsumer struct {
typingConsumer *internal.ContinualConsumer
sendToDeviceConsumer *internal.ContinualConsumer
receiptConsumer *internal.ContinualConsumer
db storage.Database
queues *queue.OutgoingQueues
ServerName gomatrixserverlib.ServerName
TypingTopic string
SendToDeviceTopic string
}
// NewOutputEDUConsumer creates a new OutputEDUConsumer. Call Start() to begin consuming from EDU servers.
func NewOutputEDUConsumer(
process *process.ProcessContext,
cfg *config.FederationAPI,
kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues,
store storage.Database,
) *OutputEDUConsumer {
c := &OutputEDUConsumer{
typingConsumer: &internal.ContinualConsumer{
Process: process,
ComponentName: "eduserver/typing",
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
},
sendToDeviceConsumer: &internal.ContinualConsumer{
Process: process,
ComponentName: "eduserver/sendtodevice",
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
},
receiptConsumer: &internal.ContinualConsumer{
Process: process,
ComponentName: "eduserver/receipt",
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent),
Consumer: kafkaConsumer,
PartitionStore: store,
},
queues: queues,
db: store,
ServerName: cfg.Matrix.ServerName,
TypingTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent),
SendToDeviceTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent),
}
c.typingConsumer.ProcessMessage = c.onTypingEvent
c.sendToDeviceConsumer.ProcessMessage = c.onSendToDeviceEvent
c.receiptConsumer.ProcessMessage = c.onReceiptEvent
return c
}
// Start consuming from EDU servers
func (t *OutputEDUConsumer) Start() error {
if err := t.typingConsumer.Start(); err != nil {
return fmt.Errorf("t.typingConsumer.Start: %w", err)
}
if err := t.sendToDeviceConsumer.Start(); err != nil {
return fmt.Errorf("t.sendToDeviceConsumer.Start: %w", err)
}
if err := t.receiptConsumer.Start(); err != nil {
return fmt.Errorf("t.receiptConsumer.Start: %w", err)
}
return nil
}
// onSendToDeviceEvent is called in response to a message received on the
// send-to-device events topic from the EDU server.
func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *sarama.ConsumerMessage) error {
// Extract the send-to-device event from msg.
var ote api.OutputSendToDeviceEvent
if err := json.Unmarshal(msg.Value, &ote); err != nil {
log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
return nil
}
// only send send-to-device events which originated from us
_, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender)
if err != nil {
log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender")
return nil
}
if originServerName != t.ServerName {
log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere")
return nil
}
_, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID)
if err != nil {
log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination")
return nil
}
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDirectToDevice,
Origin: string(t.ServerName),
}
tdm := gomatrixserverlib.ToDeviceMessage{
Sender: ote.Sender,
Type: ote.Type,
MessageID: util.RandomString(32),
Messages: map[string]map[string]json.RawMessage{
ote.UserID: {
ote.DeviceID: ote.Content,
},
},
}
if edu.Content, err = json.Marshal(tdm); err != nil {
return err
}
log.Infof("Sending send-to-device message into %q destination queue", destServerName)
return t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName})
}
// onTypingEvent is called in response to a message received on the typing
// events topic from the EDU server.
func (t *OutputEDUConsumer) onTypingEvent(msg *sarama.ConsumerMessage) error {
// Extract the typing event from msg.
var ote api.OutputTypingEvent
if err := json.Unmarshal(msg.Value, &ote); err != nil {
// Skip this msg but continue processing messages.
log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)")
return nil
}
// only send typing events which originated from us
_, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID)
if err != nil {
log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender")
return nil
}
if typingServerName != t.ServerName {
log.WithField("other_server", typingServerName).Info("Suppressing typing notif: originated elsewhere")
return nil
}
joined, err := t.db.GetJoinedHosts(context.TODO(), ote.Event.RoomID)
if err != nil {
return err
}
names := make([]gomatrixserverlib.ServerName, len(joined))
for i := range joined {
names[i] = joined[i].ServerName
}
edu := &gomatrixserverlib.EDU{Type: ote.Event.Type}
if edu.Content, err = json.Marshal(map[string]interface{}{
"room_id": ote.Event.RoomID,
"user_id": ote.Event.UserID,
"typing": ote.Event.Typing,
}); err != nil {
return err
}
return t.queues.SendEDU(edu, t.ServerName, names)
}
// onReceiptEvent is called in response to a message received on the receipt
// events topic from the EDU server.
func (t *OutputEDUConsumer) onReceiptEvent(msg *sarama.ConsumerMessage) error {
// Extract the typing event from msg.
var receipt api.OutputReceiptEvent
if err := json.Unmarshal(msg.Value, &receipt); err != nil {
// Skip this msg but continue processing messages.
log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)")
return nil
}
// only send receipt events which originated from us
_, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID)
if err != nil {
log.WithError(err).WithField("user_id", receipt.UserID).Error("Failed to extract domain from receipt sender")
return nil
}
if receiptServerName != t.ServerName {
return nil // don't log, very spammy as it logs for each remote receipt
}
joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID)
if err != nil {
return err
}
names := make([]gomatrixserverlib.ServerName, len(joined))
for i := range joined {
names[i] = joined[i].ServerName
}
content := map[string]api.FederationReceiptMRead{}
content[receipt.RoomID] = api.FederationReceiptMRead{
User: map[string]api.FederationReceiptData{
receipt.UserID: {
Data: api.ReceiptTS{
TS: receipt.Timestamp,
},
EventIDs: []string{receipt.EventID},
},
},
}
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MReceipt,
Origin: string(t.ServerName),
}
if edu.Content, err = json.Marshal(content); err != nil {
return err
}
return t.queues.SendEDU(edu, t.ServerName, names)
}

View file

@ -0,0 +1,202 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"fmt"
"github.com/Shopify/sarama"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// KeyChangeConsumer consumes events that originate in key server.
type KeyChangeConsumer struct {
consumer *internal.ContinualConsumer
db storage.Database
queues *queue.OutgoingQueues
serverName gomatrixserverlib.ServerName
rsAPI roomserverAPI.RoomserverInternalAPI
}
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
func NewKeyChangeConsumer(
process *process.ProcessContext,
cfg *config.KeyServer,
kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues,
store storage.Database,
rsAPI roomserverAPI.RoomserverInternalAPI,
) *KeyChangeConsumer {
c := &KeyChangeConsumer{
consumer: &internal.ContinualConsumer{
Process: process,
ComponentName: "federationapi/keychange",
Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)),
Consumer: kafkaConsumer,
PartitionStore: store,
},
queues: queues,
db: store,
serverName: cfg.Matrix.ServerName,
rsAPI: rsAPI,
}
c.consumer.ProcessMessage = c.onMessage
return c
}
// Start consuming from key servers
func (t *KeyChangeConsumer) Start() error {
if err := t.consumer.Start(); err != nil {
return fmt.Errorf("t.consumer.Start: %w", err)
}
return nil
}
// onMessage is called in response to a message received on the
// key change events topic from the key server.
func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
var m api.DeviceMessage
if err := json.Unmarshal(msg.Value, &m); err != nil {
logrus.WithError(err).Errorf("failed to read device message from key change topic")
return nil
}
if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil {
// This probably shouldn't happen but stops us from panicking if we come
// across an update that doesn't satisfy either types.
return nil
}
switch m.Type {
case api.TypeCrossSigningUpdate:
return t.onCrossSigningMessage(m)
case api.TypeDeviceKeyUpdate:
fallthrough
default:
return t.onDeviceKeyMessage(m)
}
}
func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
logger := logrus.WithField("user_id", m.UserID)
// only send key change events which originated from us
_, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
if err != nil {
logger.WithError(err).Error("Failed to extract domain from key change event")
return nil
}
if originServerName != t.serverName {
return nil
}
var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
UserID: m.UserID,
WantMembership: "join",
}, &queryRes)
if err != nil {
logger.WithError(err).Error("failed to calculate joined rooms for user")
return nil
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
if err != nil {
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
return nil
}
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: gomatrixserverlib.MDeviceListUpdate,
Origin: string(t.serverName),
}
event := gomatrixserverlib.DeviceListUpdateEvent{
UserID: m.UserID,
DeviceID: m.DeviceID,
DeviceDisplayName: m.DisplayName,
StreamID: m.StreamID,
PrevID: prevID(m.StreamID),
Deleted: len(m.KeyJSON) == 0,
Keys: m.KeyJSON,
}
if edu.Content, err = json.Marshal(event); err != nil {
return err
}
logrus.Infof("Sending device list update message to %q", destinations)
return t.queues.SendEDU(edu, t.serverName, destinations)
}
func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
output := m.CrossSigningKeyUpdate
_, host, err := gomatrixserverlib.SplitID('@', output.UserID)
if err != nil {
logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure")
return nil
}
if host != gomatrixserverlib.ServerName(t.serverName) {
// Ignore any messages that didn't originate locally, otherwise we'll
// end up parroting information we received from other servers.
return nil
}
logger := logrus.WithField("user_id", output.UserID)
var queryRes roomserverAPI.QueryRoomsForUserResponse
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
UserID: output.UserID,
WantMembership: "join",
}, &queryRes)
if err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
return nil
}
// send this key change to all servers who share rooms with this user.
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
if err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
return nil
}
// Pack the EDU and marshal it
edu := &gomatrixserverlib.EDU{
Type: eduserverAPI.MSigningKeyUpdate,
Origin: string(t.serverName),
}
if edu.Content, err = json.Marshal(output); err != nil {
logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping")
return nil
}
logger.Infof("Sending cross-signing update message to %q", destinations)
return t.queues.SendEDU(edu, t.serverName, destinations)
}
func prevID(streamID int) []int {
if streamID <= 1 {
return nil
}
return []int{streamID - 1}
}

View file

@ -0,0 +1,407 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"context"
"encoding/json"
"fmt"
"github.com/Shopify/sarama"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus"
)
// OutputRoomEventConsumer consumes events that originated in the room server.
type OutputRoomEventConsumer struct {
cfg *config.FederationAPI
rsAPI api.RoomserverInternalAPI
rsConsumer *internal.ContinualConsumer
db storage.Database
queues *queue.OutgoingQueues
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
func NewOutputRoomEventConsumer(
process *process.ProcessContext,
cfg *config.FederationAPI,
kafkaConsumer sarama.Consumer,
queues *queue.OutgoingQueues,
store storage.Database,
rsAPI api.RoomserverInternalAPI,
) *OutputRoomEventConsumer {
consumer := internal.ContinualConsumer{
Process: process,
ComponentName: "federationapi/roomserver",
Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
Consumer: kafkaConsumer,
PartitionStore: store,
}
s := &OutputRoomEventConsumer{
cfg: cfg,
rsConsumer: &consumer,
db: store,
queues: queues,
rsAPI: rsAPI,
}
consumer.ProcessMessage = s.onMessage
return s
}
// Start consuming from room servers
func (s *OutputRoomEventConsumer) Start() error {
return s.rsConsumer.Start()
}
// onMessage is called when the federation server receives a new event from the room server output log.
// It is unsafe to call this with messages for the same room in multiple gorountines
// because updates it will likely fail with a types.EventIDMismatchError when it
// realises that it cannot update the room state using the deltas.
func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
// Parse out the event JSON
var output api.OutputEvent
if err := json.Unmarshal(msg.Value, &output); err != nil {
// If the message was invalid, log it and move on to the next message in the stream
log.WithError(err).Errorf("roomserver output log: message parse failure")
return nil
}
switch output.Type {
case api.OutputTypeNewRoomEvent:
ev := output.NewRoomEvent.Event
if output.NewRoomEvent.RewritesState {
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
return fmt.Errorf("s.db.PurgeRoom: %w", err)
}
}
if err := s.processMessage(*output.NewRoomEvent); err != nil {
switch err.(type) {
case *queue.ErrorFederationDisabled:
log.WithField("error", output.Type).Info(
err.Error(),
)
default:
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event_id": ev.EventID(),
"event": string(ev.JSON()),
"add": output.NewRoomEvent.AddsStateEventIDs,
"del": output.NewRoomEvent.RemovesStateEventIDs,
log.ErrorKey: err,
}).Panicf("roomserver output log: write room event failure")
}
return nil
}
case api.OutputTypeNewInboundPeek:
if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
log.WithFields(log.Fields{
"event": output.NewInboundPeek,
log.ErrorKey: err,
}).Panicf("roomserver output log: remote peek event failure")
return nil
}
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
)
return nil
}
return nil
}
// processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)
// causing the federationapi to start sending messages to the peeking server
func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPeek) error {
// FIXME: there's a race here - we should start /sending new peeked events
// atomically after the orp.LatestEventID to ensure there are no gaps between
// the peek beginning and the send stream beginning.
//
// We probably need to track orp.LatestEventID on the inbound peek, but it's
// unclear how we then use that to prevent the race when we start the send
// stream.
//
// This is making the tests flakey.
return s.db.AddInboundPeek(context.TODO(), orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval)
}
// processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState()))
if err != nil {
return err
}
// Update our copy of the current state.
// We keep a copy of the current state because the state at each event is
// expressed as a delta against the current state.
// TODO(#290): handle EventIDMismatchError and recover the current state by
// talking to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom(
context.TODO(),
ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
addsJoinedHosts,
ore.RemovesStateEventIDs,
)
if err != nil {
return err
}
if oldJoinedHosts == nil {
// This means that there is nothing to update as this is a duplicate
// message.
// This can happen if dendrite crashed between reading the message and
// persisting the stream position.
return nil
}
if ore.SendAsServer == api.DoNotSendToOtherServers {
// Ignore event that we don't need to send anywhere.
return nil
}
// Work out which hosts were joined at the event itself.
joinedHostsAtEvent, err := s.joinedHostsAtEvent(ore, oldJoinedHosts)
if err != nil {
return err
}
// TODO: do housekeeping to evict unrenewed peeking hosts
// TODO: implement query to let the fedapi check whether a given peek is live or not
// Send the event.
return s.queues.SendEvent(
ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent,
)
}
// joinedHostsAtEvent works out a list of matrix servers that were joined to
// the room at the event (including peeking ones)
// It is important to use the state at the event for sending messages because:
// 1) We shouldn't send messages to servers that weren't in the room.
// 2) If a server is kicked from the rooms it should still be told about the
// kick event,
// Usually the list can be calculated locally, but sometimes it will need fetch
// events from the room server.
// Returns an error if there was a problem talking to the room server.
func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost,
) ([]gomatrixserverlib.ServerName, error) {
// Combine the delta into a single delta so that the adds and removes can
// cancel each other out. This should reduce the number of times we need
// to fetch a state event from the room server.
combinedAdds, combinedRemoves := combineDeltas(
ore.AddsStateEventIDs, ore.RemovesStateEventIDs,
ore.StateBeforeAddsEventIDs, ore.StateBeforeRemovesEventIDs,
)
combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ore.Event.Event)
if err != nil {
return nil, err
}
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
if err != nil {
return nil, err
}
removed := map[string]bool{}
for _, eventID := range combinedRemoves {
removed[eventID] = true
}
joined := map[gomatrixserverlib.ServerName]bool{}
for _, joinedHost := range oldJoinedHosts {
if removed[joinedHost.MemberEventID] {
// This m.room.member event is part of the current state of the
// room, but not part of the state at the event we are processing
// Therefore we can't use it to tell whether the server was in
// the room at the event.
continue
}
joined[joinedHost.ServerName] = true
}
for _, joinedHost := range combinedAddsJoinedHosts {
// This m.room.member event was part of the state of the room at the
// event, but isn't part of the current state of the room now.
joined[joinedHost.ServerName] = true
}
// handle peeking hosts
inboundPeeks, err := s.db.GetInboundPeeks(context.TODO(), ore.Event.Event.RoomID())
if err != nil {
return nil, err
}
for _, inboundPeek := range inboundPeeks {
joined[inboundPeek.ServerName] = true
}
var result []gomatrixserverlib.ServerName
for serverName, include := range joined {
if include {
result = append(result, serverName)
}
}
return result, nil
}
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
// This errors if one of the events was invalid.
// It should be impossible for an invalid event to get this far in the pipeline.
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
var joinedHosts []types.JoinedHost
for _, ev := range evs {
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
continue
}
membership, err := ev.Membership()
if err != nil {
return nil, err
}
if membership != gomatrixserverlib.Join {
continue
}
_, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
if err != nil {
return nil, err
}
joinedHosts = append(joinedHosts, types.JoinedHost{
MemberEventID: ev.EventID(), ServerName: serverName,
})
}
return joinedHosts, nil
}
// combineDeltas combines two deltas into a single delta.
// Assumes that the order of operations is add(1), remove(1), add(2), remove(2).
// Removes duplicate entries and redundant operations from each delta.
func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []string) {
addSet := map[string]bool{}
removeSet := map[string]bool{}
// combine processes each unique value in a list.
// If the value is in the removeFrom set then it is removed from that set.
// Otherwise it is added to the addTo set.
combine := func(values []string, removeFrom, addTo map[string]bool) {
processed := map[string]bool{}
for _, value := range values {
if processed[value] {
continue
}
processed[value] = true
if removeFrom[value] {
delete(removeFrom, value)
} else {
addTo[value] = true
}
}
}
combine(adds1, nil, addSet)
combine(removes1, addSet, removeSet)
combine(adds2, removeSet, addSet)
combine(removes2, addSet, removeSet)
for value := range addSet {
adds = append(adds, value)
}
for value := range removeSet {
removes = append(removes, value)
}
return
}
// lookupStateEvents looks up the state events that are added by a new event.
func (s *OutputRoomEventConsumer) lookupStateEvents(
addsStateEventIDs []string, event *gomatrixserverlib.Event,
) ([]*gomatrixserverlib.Event, error) {
// Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 {
return nil, nil
}
// Fast path if the only state event added is the event itself.
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
return []*gomatrixserverlib.Event{event}, nil
}
missing := addsStateEventIDs
var result []*gomatrixserverlib.Event
// Check if event itself is being added.
for _, eventID := range missing {
if eventID == event.EventID() {
result = append(result, event)
break
}
}
missing = missingEventsFrom(result, addsStateEventIDs)
if len(missing) == 0 {
return result, nil
}
// At this point the missing events are neither the event itself nor are
// they present in our local database. Our only option is to fetch them
// from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
var eventResp api.QueryEventsByIDResponse
if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil {
return nil, err
}
for _, headeredEvent := range eventResp.Events {
result = append(result, headeredEvent.Event)
}
missing = missingEventsFrom(result, addsStateEventIDs)
if len(missing) != 0 {
return nil, fmt.Errorf(
"missing %d state events IDs at event %q", len(missing), event.EventID(),
)
}
return result, nil
}
func missingEventsFrom(events []*gomatrixserverlib.Event, required []string) []string {
have := map[string]bool{}
for _, event := range events {
have[event.EventID()] = true
}
var missing []string
for _, eventID := range required {
if !have[eventID] {
missing = append(missing, eventID)
}
}
return missing
}

View file

@ -0,0 +1,53 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package consumers
import (
"testing"
)
func TestCombineNoOp(t *testing.T) {
inputAdd1 := []string{"a", "b", "c"}
inputDel1 := []string{"a", "b", "d"}
inputAdd2 := []string{"a", "d", "e"}
inputDel2 := []string{"a", "c", "e", "e"}
gotAdd, gotDel := combineDeltas(inputAdd1, inputDel1, inputAdd2, inputDel2)
if len(gotAdd) != 0 {
t.Errorf("wanted combined adds to be an empty list, got %#v", gotAdd)
}
if len(gotDel) != 0 {
t.Errorf("wanted combined removes to be an empty list, got %#v", gotDel)
}
}
func TestCombineDedup(t *testing.T) {
inputAdd1 := []string{"a", "a"}
inputDel1 := []string{"b", "b"}
inputAdd2 := []string{"a", "a"}
inputDel2 := []string{"b", "b"}
gotAdd, gotDel := combineDeltas(inputAdd1, inputDel1, inputAdd2, inputDel2)
if len(gotAdd) != 1 || gotAdd[0] != "a" {
t.Errorf("wanted combined adds to be %#v, got %#v", []string{"a"}, gotAdd)
}
if len(gotDel) != 1 || gotDel[0] != "b" {
t.Errorf("wanted combined removes to be %#v, got %#v", []string{"b"}, gotDel)
}
}

View file

@ -17,17 +17,33 @@ package federationapi
import (
"github.com/gorilla/mux"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
"github.com/matrix-org/dendrite/federationapi/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/federationapi/consumers"
"github.com/matrix-org/dendrite/federationapi/internal"
"github.com/matrix-org/dendrite/federationapi/inthttp"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/internal/caching"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/setup/kafka"
userapi "github.com/matrix-org/dendrite/userapi/api"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/gomatrixserverlib"
)
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
// on the given input API.
func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) {
inthttp.AddRoutes(intAPI, router)
}
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
func AddPublicRoutes(
fedRouter, keyRouter, wellKnownRouter *mux.Router,
@ -36,7 +52,7 @@ func AddPublicRoutes(
federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier,
rsAPI roomserverAPI.RoomserverInternalAPI,
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
federationAPI federationAPI.FederationInternalAPI,
eduAPI eduserverAPI.EDUServerInputAPI,
keyAPI keyserverAPI.KeyInternalAPI,
mscCfg *config.MSCs,
@ -44,8 +60,70 @@ func AddPublicRoutes(
) {
routing.Setup(
fedRouter, keyRouter, wellKnownRouter, cfg, rsAPI,
eduAPI, federationSenderAPI, keyRing,
eduAPI, federationAPI, keyRing,
federation, userAPI, keyAPI, mscCfg,
servers,
)
}
// NewInternalAPI returns a concerete implementation of the internal API. Callers
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
func NewInternalAPI(
base *base.BaseDendrite,
federation *gomatrixserverlib.FederationClient,
rsAPI roomserverAPI.RoomserverInternalAPI,
caches *caching.Caches,
resetBlacklist bool,
) api.FederationInternalAPI {
cfg := &base.Cfg.FederationAPI
federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches)
if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db")
}
if resetBlacklist {
_ = federationDB.RemoveAllServersFromBlacklist()
}
stats := &statistics.Statistics{
DB: federationDB,
FailuresUntilBlacklist: cfg.FederationMaxRetries,
}
consumer, _ := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
queues := queue.NewOutgoingQueues(
federationDB, base.ProcessContext,
cfg.Matrix.DisableFederation,
cfg.Matrix.ServerName, federation, rsAPI, stats,
&queue.SigningInfo{
KeyID: cfg.Matrix.KeyID,
PrivateKey: cfg.Matrix.PrivateKey,
ServerName: cfg.Matrix.ServerName,
},
)
rsConsumer := consumers.NewOutputRoomEventConsumer(
base.ProcessContext, cfg, consumer, queues,
federationDB, rsAPI,
)
if err = rsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start room server consumer")
}
tsConsumer := consumers.NewOutputEDUConsumer(
base.ProcessContext, cfg, consumer, queues, federationDB,
)
if err := tsConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start typing server consumer")
}
keyConsumer := consumers.NewKeyChangeConsumer(
base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI,
)
if err := keyConsumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start key server consumer")
}
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues)
}

View file

@ -0,0 +1,320 @@
package federationapi
import (
"bytes"
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"reflect"
"testing"
"time"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
)
type server struct {
name gomatrixserverlib.ServerName // server name
validity time.Duration // key validity duration from now
config *config.FederationAPI // skeleton config, from TestMain
fedclient *gomatrixserverlib.FederationClient // uses MockRoundTripper
cache *caching.Caches // server-specific cache
api api.FederationInternalAPI // server-specific server key API
}
func (s *server) renew() {
// This updates the validity period to be an hour in the
// future, which is particularly useful in server A and
// server C's cases which have validity either as now or
// in the past.
s.validity = time.Hour
s.config.Matrix.KeyValidityPeriod = s.validity
}
var (
serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")
serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now
serverB = &server{name: "b.com", validity: time.Hour} // expires in an hour
serverC = &server{name: "c.com", validity: -time.Hour} // expired an hour ago
)
var servers = map[string]*server{
"a.com": serverA,
"b.com": serverB,
"c.com": serverC,
}
func TestMain(m *testing.M) {
// Set up the server key API for each "server" that we
// will use in our tests.
for _, s := range servers {
// Generate a new key.
_, testPriv, err := ed25519.GenerateKey(nil)
if err != nil {
panic("can't generate identity key: " + err.Error())
}
// Create a new cache but don't enable prometheus!
s.cache, err = caching.NewInMemoryLRUCache(false)
if err != nil {
panic("can't create cache: " + err.Error())
}
// Draw up just enough Dendrite config for the server key
// API to work.
cfg := &config.Dendrite{}
cfg.Defaults()
cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name)
cfg.Global.PrivateKey = testPriv
cfg.Global.Kafka.UseNaffka = true
cfg.Global.KeyID = serverKeyID
cfg.Global.KeyValidityPeriod = s.validity
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
s.config = &cfg.FederationAPI
// Create a transport which redirects federation requests to
// the mock round tripper. Since we're not *really* listening for
// federation requests then this will return the key instead.
transport := &http.Transport{}
transport.RegisterProtocol("matrix", &MockRoundTripper{})
// Create the federation client.
s.fedclient = gomatrixserverlib.NewFederationClient(
s.config.Matrix.ServerName, serverKeyID, testPriv,
gomatrixserverlib.WithTransport(transport),
)
// Finally, build the server key APIs.
sbase := base.NewBaseDendrite(cfg, "Monolith", base.NoCacheMetrics)
s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, true)
}
// Now that we have built our server key APIs, start the
// rest of the tests.
os.Exit(m.Run())
}
type MockRoundTripper struct{}
func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
// Check if the request is looking for keys from a server that
// we know about in the test. The only reason this should go wrong
// is if the test is broken.
s, ok := servers[req.Host]
if !ok {
return nil, fmt.Errorf("server not known: %s", req.Host)
}
// We're intercepting /matrix/key/v2/server requests here, so check
// that the URL supplied in the request is for that.
if req.URL.Path != "/_matrix/key/v2/server" {
return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path)
}
// Get the keys and JSON-ify them.
keys := routing.LocalKeys(s.config)
body, err := json.MarshalIndent(keys.JSON, "", " ")
if err != nil {
return nil, err
}
// And respond.
res = &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(body)),
}
return
}
func TestServersRequestOwnKeys(t *testing.T) {
// Each server will request its own keys. There's no reason
// for this to fail as each server should know its own keys.
for name, s := range servers {
req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: s.name,
KeyID: serverKeyID,
}
res, err := s.api.FetchKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: gomatrixserverlib.AsTimestamp(time.Now()),
},
)
if err != nil {
t.Fatalf("server could not fetch own key: %s", err)
}
if _, ok := res[req]; !ok {
t.Fatalf("server didn't return its own key in the results")
}
t.Logf("%s's key expires at %s\n", name, res[req].ValidUntilTS.Time())
}
}
func TestCachingBehaviour(t *testing.T) {
// Server A will request Server B's key, which has a validity
// period of an hour from now. We should retrieve the key and
// it should make it into the cache automatically.
req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverB.name,
KeyID: serverKeyID,
}
ts := gomatrixserverlib.AsTimestamp(time.Now())
res, err := serverA.api.FetchKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: ts,
},
)
if err != nil {
t.Fatalf("server A failed to retrieve server B key: %s", err)
}
if len(res) != 1 {
t.Fatalf("server B should have returned one key but instead returned %d keys", len(res))
}
if _, ok := res[req]; !ok {
t.Fatalf("server B isn't included in the key fetch response")
}
// At this point, if the previous key request was a success,
// then the cache should now contain the key. Check if that's
// the case - if it isn't then there's something wrong with
// the cache implementation or we failed to get the key.
cres, ok := serverA.cache.GetServerKey(req, ts)
if !ok {
t.Fatalf("server B key should be in cache but isn't")
}
if !reflect.DeepEqual(cres, res[req]) {
t.Fatalf("the cached result from server B wasn't what server B gave us")
}
// If we ask the cache for the same key but this time for an event
// that happened in +30 minutes. Since the validity period is for
// another hour, then we should get a response back from the cache.
_, ok = serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*30)),
)
if !ok {
t.Fatalf("server B key isn't in cache when it should be (+30 minutes)")
}
// If we ask the cache for the same key but this time for an event
// that happened in +90 minutes then we should expect to get no
// cache result. This is because the cache shouldn't return a result
// that is obviously past the validity of the event.
_, ok = serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*90)),
)
if ok {
t.Fatalf("server B key is in cache when it shouldn't be (+90 minutes)")
}
}
func TestRenewalBehaviour(t *testing.T) {
// Server A will request Server C's key but their validity period
// is an hour in the past. We'll retrieve the key as, even though it's
// past its validity, it will be able to verify past events.
req := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: serverC.name,
KeyID: serverKeyID,
}
res, err := serverA.api.FetchKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: gomatrixserverlib.AsTimestamp(time.Now()),
},
)
if err != nil {
t.Fatalf("server A failed to retrieve server C key: %s", err)
}
if len(res) != 1 {
t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
}
if _, ok := res[req]; !ok {
t.Fatalf("server C isn't included in the key fetch response")
}
// If we ask the cache for the server key for an event that happened
// 90 minutes ago then we should get a cache result, as the key hadn't
// passed its validity by that point. The fact that the key is now in
// the cache is, in itself, proof that we successfully retrieved the
// key before.
oldcached, ok := serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*90)),
)
if !ok {
t.Fatalf("server C key isn't in cache when it should be (-90 minutes)")
}
// If we now ask the cache for the same key but this time for an event
// that only happened 30 minutes ago then we shouldn't get a cached
// result, as the event happened after the key validity expired. This
// is really just for sanity checking.
_, ok = serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
)
if ok {
t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)")
}
// We're now going to kick server C into renewing its key. Since we're
// happy at this point that the key that we already have is from the past
// then repeating a key fetch should cause us to try and renew the key.
// If so, then the new key will end up in our cache.
serverC.renew()
res, err = serverA.api.FetchKeys(
context.Background(),
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
req: gomatrixserverlib.AsTimestamp(time.Now()),
},
)
if err != nil {
t.Fatalf("server A failed to retrieve server C key: %s", err)
}
if len(res) != 1 {
t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
}
if _, ok = res[req]; !ok {
t.Fatalf("server C isn't included in the key fetch response")
}
// We're now going to ask the cache what the new key validity is. If
// it is still the same as the previous validity then we've failed to
// retrieve the renewed key. If it's newer then we've successfully got
// the renewed key.
newcached, ok := serverA.cache.GetServerKey(
req,
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
)
if !ok {
t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)")
}
if oldcached.ValidUntilTS >= newcached.ValidUntilTS {
t.Fatalf("the server B key should have been renewed but wasn't")
}
t.Log(res)
}

View file

@ -8,7 +8,7 @@ import (
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/internal/test"
"github.com/matrix-org/dendrite/setup"
"github.com/matrix-org/dendrite/setup/base"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
@ -25,10 +25,10 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
cfg.Global.PrivateKey = privKey
cfg.Global.Kafka.UseNaffka = true
cfg.Global.Kafka.Database.ConnectionString = config.DataSource("file::memory:")
cfg.FederationSender.Database.ConnectionString = config.DataSource("file::memory:")
base := setup.NewBaseDendrite(cfg, "Monolith", false)
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
base := base.NewBaseDendrite(cfg, "Monolith", base.NoCacheMetrics)
keyRing := &test.NopJSONVerifier{}
fsAPI := base.FederationSenderHTTPClient()
fsAPI := base.FederationAPIHTTPClient()
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, base.PublicWellKnownAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs, nil)

View file

@ -0,0 +1,304 @@
package internal
import (
"context"
"crypto/ed25519"
"encoding/base64"
"sync"
"time"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/federationapi/queue"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/cache"
"github.com/matrix-org/dendrite/internal/caching"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
// FederationInternalAPI is an implementation of api.FederationInternalAPI
type FederationInternalAPI struct {
db storage.Database
cfg *config.FederationAPI
statistics *statistics.Statistics
rsAPI roomserverAPI.RoomserverInternalAPI
federation *gomatrixserverlib.FederationClient
keyRing *gomatrixserverlib.KeyRing
queues *queue.OutgoingQueues
joins sync.Map // joins currently in progress
}
func NewFederationInternalAPI(
db storage.Database, cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
federation *gomatrixserverlib.FederationClient,
statistics *statistics.Statistics,
caches *caching.Caches,
queues *queue.OutgoingQueues,
) *FederationInternalAPI {
serverKeyDB, err := cache.NewKeyDatabase(db, caches)
if err != nil {
logrus.WithError(err).Panicf("failed to set up caching wrapper for server key database")
}
keyRing := &gomatrixserverlib.KeyRing{
KeyFetchers: []gomatrixserverlib.KeyFetcher{},
KeyDatabase: serverKeyDB,
}
addDirectFetcher := func() {
keyRing.KeyFetchers = append(
keyRing.KeyFetchers,
&gomatrixserverlib.DirectKeyFetcher{
Client: federation,
},
)
}
if cfg.PreferDirectFetch {
addDirectFetcher()
} else {
defer addDirectFetcher()
}
var b64e = base64.StdEncoding.WithPadding(base64.NoPadding)
for _, ps := range cfg.KeyPerspectives {
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
PerspectiveServerName: ps.ServerName,
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
Client: federation,
}
for _, key := range ps.Keys {
rawkey, err := b64e.DecodeString(key.PublicKey)
if err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": ps.ServerName,
"public_key": key.PublicKey,
}).Warn("Couldn't parse perspective key")
continue
}
perspective.PerspectiveServerKeys[key.KeyID] = rawkey
}
keyRing.KeyFetchers = append(keyRing.KeyFetchers, perspective)
logrus.WithFields(logrus.Fields{
"server_name": ps.ServerName,
"num_public_keys": len(ps.Keys),
}).Info("Enabled perspective key fetcher")
}
return &FederationInternalAPI{
db: db,
cfg: cfg,
rsAPI: rsAPI,
keyRing: keyRing,
federation: federation,
statistics: statistics,
queues: queues,
}
}
func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) {
stats := a.statistics.ForServer(s)
until, blacklisted := stats.BackoffInfo()
if blacklisted {
return stats, &api.FederationClientError{
Blacklisted: true,
}
}
now := time.Now()
if until != nil && now.Before(*until) {
return stats, &api.FederationClientError{
RetryAfter: time.Until(*until),
}
}
return stats, nil
}
func failBlacklistableError(err error, stats *statistics.ServerStatistics) (until time.Time, blacklisted bool) {
if err == nil {
return
}
mxerr, ok := err.(gomatrix.HTTPError)
if !ok {
return stats.Failure()
}
if mxerr.Code == 401 { // invalid signature in X-Matrix header
return stats.Failure()
}
if mxerr.Code >= 500 && mxerr.Code < 600 { // internal server errors
return stats.Failure()
}
return
}
func (a *FederationInternalAPI) doRequest(
s gomatrixserverlib.ServerName, request func() (interface{}, error),
) (interface{}, error) {
stats, err := a.isBlacklistedOrBackingOff(s)
if err != nil {
return nil, err
}
res, err := request()
if err != nil {
until, blacklisted := failBlacklistableError(err, stats)
now := time.Now()
var retryAfter time.Duration
if until.After(now) {
retryAfter = time.Until(until)
}
return res, &api.FederationClientError{
Err: err.Error(),
Blacklisted: blacklisted,
RetryAfter: retryAfter,
}
}
stats.Success()
return res, nil
}
func (a *FederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetUserDevices(ctx, s, userID)
})
if err != nil {
return gomatrixserverlib.RespUserDevices{}, err
}
return ires.(gomatrixserverlib.RespUserDevices), nil
}
func (a *FederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.ClaimKeys(ctx, s, oneTimeKeys)
})
if err != nil {
return gomatrixserverlib.RespClaimKeys{}, err
}
return ires.(gomatrixserverlib.RespClaimKeys), nil
}
func (a *FederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) {
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.QueryKeys(ctx, s, keys)
})
if err != nil {
return gomatrixserverlib.RespQueryKeys{}, err
}
return ires.(gomatrixserverlib.RespQueryKeys), nil
}
func (a *FederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs)
})
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
return ires.(gomatrixserverlib.Transaction), nil
}
func (a *FederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.RespState, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion)
})
if err != nil {
return gomatrixserverlib.RespState{}, err
}
return ires.(gomatrixserverlib.RespState), nil
}
func (a *FederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (res gomatrixserverlib.RespStateIDs, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupStateIDs(ctx, s, roomID, eventID)
})
if err != nil {
return gomatrixserverlib.RespStateIDs{}, err
}
return ires.(gomatrixserverlib.RespStateIDs), nil
}
func (a *FederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (res gomatrixserverlib.Transaction, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.GetEvent(ctx, s, eventID)
})
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
return ires.(gomatrixserverlib.Transaction), nil
}
func (a *FederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.LookupServerKeys(ctx, s, keyRequests)
})
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
return ires.([]gomatrixserverlib.ServerKeys), nil
}
func (a *FederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion)
})
if err != nil {
return res, err
}
return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil
}
func (a *FederationInternalAPI) MSC2946Spaces(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
ctx, cancel := context.WithTimeout(ctx, time.Minute)
defer cancel()
ires, err := a.doRequest(s, func() (interface{}, error) {
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
})
if err != nil {
return res, err
}
return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil
}

View file

@ -0,0 +1,248 @@
package internal
import (
"context"
"crypto/ed25519"
"fmt"
"time"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
)
func (s *FederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
// Return a keyring that forces requests to be proxied through the
// below functions. That way we can enforce things like validity
// and keeping the cache up-to-date.
return s.keyRing
}
func (s *FederationInternalAPI) StoreKeys(
_ context.Context,
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
// Run in a background context - we don't want to stop this work just
// because the caller gives up waiting.
ctx := context.Background()
// Store any keys that we were given in our database.
return s.keyRing.KeyDatabase.StoreKeys(ctx, results)
}
func (s *FederationInternalAPI) FetchKeys(
_ context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
// Run in a background context - we don't want to stop this work just
// because the caller gives up waiting.
ctx := context.Background()
now := gomatrixserverlib.AsTimestamp(time.Now())
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{}
for k, v := range requests {
origRequests[k] = v
}
// First, check if any of these key checks are for our own keys. If
// they are then we will satisfy them directly.
s.handleLocalKeys(ctx, requests, results)
// Then consult our local database and see if we have the requested
// keys. These might come from a cache, depending on the database
// implementation used.
if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil {
return nil, err
}
// For any key requests that we still have outstanding, next try to
// fetch them directly. We'll go through each of the key fetchers to
// ask for the remaining keys
for _, fetcher := range s.keyRing.KeyFetchers {
// If there are no more keys to look up then stop.
if len(requests) == 0 {
break
}
// Ask the fetcher to look up our keys.
if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"fetcher_name": fetcher.FetcherName(),
}).Errorf("Failed to retrieve %d key(s)", len(requests))
continue
}
}
// Check that we've actually satisfied all of the key requests that we
// were given. We should report an error if we didn't.
for req := range origRequests {
if _, ok := results[req]; !ok {
// The results don't contain anything for this specific request, so
// we've failed to satisfy it from local keys, database keys or from
// all of the fetchers. Report an error.
logrus.Warnf("Failed to retrieve key %q for server %q", req.KeyID, req.ServerName)
}
}
// Return the keys.
return results, nil
}
func (s *FederationInternalAPI) FetcherName() string {
return fmt.Sprintf("FederationInternalAPI (wrapping %q)", s.keyRing.KeyDatabase.FetcherName())
}
// handleLocalKeys handles cases where the key request contains
// a request for our own server keys, either current or old.
func (s *FederationInternalAPI) handleLocalKeys(
_ context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) {
for req := range requests {
if req.ServerName != s.cfg.Matrix.ServerName {
continue
}
if req.KeyID == s.cfg.Matrix.KeyID {
// We found a key request that is supposed to be for our own
// keys. Remove it from the request list so we don't hit the
// database or the fetchers for it.
delete(requests, req)
// Insert our own key into the response.
results[req] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)),
},
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)),
}
} else {
// The key request doesn't match our current key. Let's see
// if it matches any of our old verify keys.
for _, oldVerifyKey := range s.cfg.Matrix.OldVerifyKeys {
if req.KeyID == oldVerifyKey.KeyID {
// We found a key request that is supposed to be an expired
// key.
delete(requests, req)
// Insert our own key into the response.
results[req] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: gomatrixserverlib.VerifyKey{
Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)),
},
ExpiredTS: oldVerifyKey.ExpiredAt,
ValidUntilTS: gomatrixserverlib.PublicKeyNotValid,
}
// No need to look at the other keys.
break
}
}
}
}
}
// handleDatabaseKeys handles cases where the key requests can be
// satisfied from our local database/cache.
func (s *FederationInternalAPI) handleDatabaseKeys(
ctx context.Context,
now gomatrixserverlib.Timestamp,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
// Ask the database/cache for the keys.
dbResults, err := s.keyRing.KeyDatabase.FetchKeys(ctx, requests)
if err != nil {
return err
}
// We successfully got some keys. Add them to the results.
for req, res := range dbResults {
// The key we've retrieved from the database/cache might
// have passed its validity period, but right now, it's
// the best thing we've got, and it might be sufficient to
// verify a past event.
results[req] = res
// If the key is valid right now then we can also remove it
// from the request list as we don't need to fetch it again
// in that case. If the key isn't valid right now, then by
// leaving it in the 'requests' map, we'll try to update the
// key using the fetchers in handleFetcherKeys.
if res.WasValidAt(now, true) {
delete(requests, req)
}
}
return nil
}
// handleFetcherKeys handles cases where a fetcher can satisfy
// the remaining requests.
func (s *FederationInternalAPI) handleFetcherKeys(
ctx context.Context,
_ gomatrixserverlib.Timestamp,
fetcher gomatrixserverlib.KeyFetcher,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
logrus.WithFields(logrus.Fields{
"fetcher_name": fetcher.FetcherName(),
}).Infof("Fetching %d key(s)", len(requests))
// Create a context that limits our requests to 30 seconds.
fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30)
defer fetcherCancel()
// Try to fetch the keys.
fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests)
if err != nil {
return fmt.Errorf("fetcher.FetchKeys: %w", err)
}
// Build a map of the results that we want to commit to the
// database. We do this in a separate map because otherwise we
// might end up trying to rewrite database entries.
storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
// Now let's look at the results that we got from this fetcher.
for req, res := range fetcherResults {
if prev, ok := results[req]; ok {
// We've already got a previous entry for this request
// so let's see if the newly retrieved one contains a more
// up-to-date validity period.
if res.ValidUntilTS > prev.ValidUntilTS {
// This key is newer than the one we had so let's store
// it in the database.
storeResults[req] = res
}
} else {
// We didn't already have a previous entry for this request
// so store it in the database anyway for now.
storeResults[req] = res
}
// Update the results map with this new result. If nothing
// else, we can try verifying against this key.
results[req] = res
// Remove it from the request list so we won't re-fetch it.
delete(requests, req)
}
// Store the keys from our store map.
if err = s.keyRing.KeyDatabase.StoreKeys(context.Background(), storeResults); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"fetcher_name": fetcher.FetcherName(),
"database_name": s.keyRing.KeyDatabase.FetcherName(),
}).Errorf("Failed to store keys in the database")
return fmt.Errorf("server key API failed to store retrieved keys: %w", err)
}
if len(storeResults) > 0 {
logrus.WithFields(logrus.Fields{
"fetcher_name": fetcher.FetcherName(),
}).Infof("Updated %d of %d key(s) in database (%d keys remaining)", len(storeResults), len(results), len(requests))
}
return nil
}

View file

@ -0,0 +1,727 @@
package internal
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/version"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
)
// PerformLeaveRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformDirectoryLookup(
ctx context.Context,
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) (err error) {
dir, err := r.federation.LookupRoomAlias(
ctx,
request.ServerName,
request.RoomAlias,
)
if err != nil {
r.statistics.ForServer(request.ServerName).Failure()
return err
}
response.RoomID = dir.RoomID
response.ServerNames = dir.Servers
r.statistics.ForServer(request.ServerName).Success()
return nil
}
type federatedJoin struct {
UserID string
RoomID string
}
// PerformJoin implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformJoin(
ctx context.Context,
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
// Check that a join isn't already in progress for this user/room.
j := federatedJoin{request.UserID, request.RoomID}
if _, found := r.joins.Load(j); found {
response.LastError = &gomatrix.HTTPError{
Code: 429,
Message: `{
"errcode": "M_LIMIT_EXCEEDED",
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
}`, // TODO: Why do none of our error types play nicely with each other?
}
return
}
r.joins.Store(j, nil)
defer r.joins.Delete(j)
// Look up the supported room versions.
var supportedVersions []gomatrixserverlib.RoomVersion
for version := range version.SupportedRoomVersions() {
supportedVersions = append(supportedVersions, version)
}
// Deduplicate the server names we were provided but keep the ordering
// as this encodes useful information about which servers are most likely
// to respond.
seenSet := make(map[gomatrixserverlib.ServerName]bool)
var uniqueList []gomatrixserverlib.ServerName
for _, srv := range request.ServerNames {
if seenSet[srv] {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
request.ServerNames = uniqueList
// Try each server that we were provided until we land on one that
// successfully completes the make-join send-join dance.
var lastErr error
for _, serverName := range request.ServerNames {
if err := r.performJoinUsingServer(
ctx,
request.RoomID,
request.UserID,
request.Content,
serverName,
supportedVersions,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to join room through server")
lastErr = err
continue
}
// We're all good.
response.JoinedVia = serverName
return
}
// If we reach here then we didn't complete a join for some reason.
var httpErr gomatrix.HTTPError
if ok := errors.As(lastErr, &httpErr); ok {
httpErr.Message = string(httpErr.Contents)
// Clear the wrapped error, else serialising to JSON (in polylith mode) will fail
httpErr.WrappedError = nil
response.LastError = &httpErr
} else {
response.LastError = &gomatrix.HTTPError{
Code: 0,
WrappedError: nil,
Message: "Unknown HTTP error",
}
if lastErr != nil {
response.LastError.Message = lastErr.Error()
}
}
logrus.Errorf(
"failed to join user %q to room %q through %d server(s): last error %s",
request.UserID, request.RoomID, len(request.ServerNames), lastErr,
)
}
func (r *FederationInternalAPI) performJoinUsingServer(
ctx context.Context,
roomID, userID string,
content map[string]interface{},
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
// Try to perform a make_join using the information supplied in the
// request.
respMakeJoin, err := r.federation.MakeJoin(
ctx,
serverName,
roomID,
userID,
supportedVersions,
)
if err != nil {
// TODO: Check if the user was not allowed to join the room.
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.MakeJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
// Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd"
respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember
respMakeJoin.JoinEvent.Sender = userID
respMakeJoin.JoinEvent.StateKey = &userID
respMakeJoin.JoinEvent.RoomID = roomID
respMakeJoin.JoinEvent.Redacts = ""
if content == nil {
content = map[string]interface{}{}
}
content["membership"] = "join"
if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
}
if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.SetUnsigned: %w", err)
}
// Work out if we support the room version that has been supplied in
// the make_join response.
// "If not provided, the room version is assumed to be either "1" or "2"."
// https://matrix.org/docs/spec/server_server/unstable#get-matrix-federation-v1-make-join-roomid-userid
if respMakeJoin.RoomVersion == "" {
respMakeJoin.RoomVersion = setDefaultRoomVersionFromJoinEvent(respMakeJoin.JoinEvent)
}
if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil {
return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err)
}
// Build the join event.
event, err := respMakeJoin.JoinEvent.Build(
time.Now(),
r.cfg.Matrix.ServerName,
r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey,
respMakeJoin.RoomVersion,
)
if err != nil {
return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
}
// No longer reuse the request context from this point forward.
// We don't want the client timing out to interrupt the join.
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(context.Background())
// Try to perform a send_join using the newly built event.
respSendJoin, err := r.federation.SendJoin(
ctx,
serverName,
event,
respMakeJoin.RoomVersion,
)
if err != nil {
r.statistics.ForServer(serverName).Failure()
cancel()
return fmt.Errorf("r.federation.SendJoin: %w", err)
}
r.statistics.ForServer(serverName).Success()
// Sanity-check the join response to ensure that it has a create
// event, that the room version is known, etc.
if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil {
cancel()
return fmt.Errorf("sanityCheckAuthChain: %w", err)
}
// Process the join response in a goroutine. The idea here is
// that we'll try and wait for as long as possible for the work
// to complete, but if the client does give up waiting, we'll
// still continue to process the join anyway so that we don't
// waste the effort.
go func() {
defer cancel()
// TODO: Can we expand Check here to return a list of missing auth
// events rather than failing one at a time?
respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName))
if err != nil {
logrus.WithFields(logrus.Fields{
"room_id": roomID,
"user_id": userID,
}).WithError(err).Error("Failed to process room join response")
return
}
// If we successfully performed a send_join above then the other
// server now thinks we're a part of the room. Send the newly
// returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI,
roomserverAPI.KindNew,
respState,
event.Headered(respMakeJoin.RoomVersion),
nil,
); err != nil {
logrus.WithFields(logrus.Fields{
"room_id": roomID,
"user_id": userID,
}).WithError(err).Error("Failed to send room join response to roomserver")
return
}
}()
<-ctx.Done()
return nil
}
// PerformOutboundPeekRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformOutboundPeek(
ctx context.Context,
request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse,
) error {
// Look up the supported room versions.
var supportedVersions []gomatrixserverlib.RoomVersion
for version := range version.SupportedRoomVersions() {
supportedVersions = append(supportedVersions, version)
}
// Deduplicate the server names we were provided but keep the ordering
// as this encodes useful information about which servers are most likely
// to respond.
seenSet := make(map[gomatrixserverlib.ServerName]bool)
var uniqueList []gomatrixserverlib.ServerName
for _, srv := range request.ServerNames {
if seenSet[srv] {
continue
}
seenSet[srv] = true
uniqueList = append(uniqueList, srv)
}
request.ServerNames = uniqueList
// See if there's an existing outbound peek for this room ID with
// one of the specified servers.
if peeks, err := r.db.GetOutboundPeeks(ctx, request.RoomID); err == nil {
for _, peek := range peeks {
if _, ok := seenSet[peek.ServerName]; ok {
return nil
}
}
}
// Try each server that we were provided until we land on one that
// successfully completes the peek
var lastErr error
for _, serverName := range request.ServerNames {
if err := r.performOutboundPeekUsingServer(
ctx,
request.RoomID,
serverName,
supportedVersions,
); err != nil {
logrus.WithError(err).WithFields(logrus.Fields{
"server_name": serverName,
"room_id": request.RoomID,
}).Warnf("Failed to peek room through server")
lastErr = err
continue
}
// We're all good.
return nil
}
// If we reach here then we didn't complete a peek for some reason.
var httpErr gomatrix.HTTPError
if ok := errors.As(lastErr, &httpErr); ok {
httpErr.Message = string(httpErr.Contents)
// Clear the wrapped error, else serialising to JSON (in polylith mode) will fail
httpErr.WrappedError = nil
response.LastError = &httpErr
} else {
response.LastError = &gomatrix.HTTPError{
Code: 0,
WrappedError: nil,
Message: lastErr.Error(),
}
}
logrus.Errorf(
"failed to peek room %q through %d server(s): last error %s",
request.RoomID, len(request.ServerNames), lastErr,
)
return lastErr
}
func (r *FederationInternalAPI) performOutboundPeekUsingServer(
ctx context.Context,
roomID string,
serverName gomatrixserverlib.ServerName,
supportedVersions []gomatrixserverlib.RoomVersion,
) error {
// create a unique ID for this peek.
// for now we just use the room ID again. In future, if we ever
// support concurrent peeks to the same room with different filters
// then we would need to disambiguate further.
peekID := roomID
// check whether we're peeking already to try to avoid needlessly
// re-peeking on the server. we don't need a transaction for this,
// given this is a nice-to-have.
outboundPeek, err := r.db.GetOutboundPeek(ctx, serverName, roomID, peekID)
if err != nil {
return err
}
renewing := false
if outboundPeek != nil {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
if nowMilli > outboundPeek.RenewedTimestamp+outboundPeek.RenewalInterval {
logrus.Infof("stale outbound peek to %s for %s already exists; renewing", serverName, roomID)
renewing = true
} else {
logrus.Infof("live outbound peek to %s for %s already exists", serverName, roomID)
return nil
}
}
// Try to perform an outbound /peek using the information supplied in the
// request.
respPeek, err := r.federation.Peek(
ctx,
serverName,
roomID,
peekID,
supportedVersions,
)
if err != nil {
r.statistics.ForServer(serverName).Failure()
return fmt.Errorf("r.federation.Peek: %w", err)
}
r.statistics.ForServer(serverName).Success()
// Work out if we support the room version that has been supplied in
// the peek response.
if respPeek.RoomVersion == "" {
respPeek.RoomVersion = gomatrixserverlib.RoomVersionV1
}
if _, err = respPeek.RoomVersion.EventFormat(); err != nil {
return fmt.Errorf("respPeek.RoomVersion.EventFormat: %w", err)
}
// we have the peek state now so let's process regardless of whether upstream gives up
ctx = context.Background()
respState := respPeek.ToRespState()
// authenticate the state returned (check its auth events etc)
// the equivalent of CheckSendJoinResponse()
if err = sanityCheckAuthChain(respState.AuthEvents); err != nil {
return fmt.Errorf("sanityCheckAuthChain: %w", err)
}
if err = respState.Check(ctx, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil {
return fmt.Errorf("error checking state returned from peeking: %w", err)
}
// If we've got this far, the remote server is peeking.
if renewing {
if err = r.db.RenewOutboundPeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil {
return err
}
} else {
if err = r.db.AddOutboundPeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil {
return err
}
}
// logrus.Warnf("got respPeek %#v", respPeek)
// Send the newly returned state to the roomserver to update our local view.
if err = roomserverAPI.SendEventWithState(
ctx, r.rsAPI,
roomserverAPI.KindNew,
&respState,
respPeek.LatestEvent.Headered(respPeek.RoomVersion),
nil,
); err != nil {
return fmt.Errorf("r.producer.SendEventWithState: %w", err)
}
return nil
}
// PerformLeaveRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformLeave(
ctx context.Context,
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) (err error) {
// Deduplicate the server names we were provided.
util.SortAndUnique(request.ServerNames)
// Try each server that we were provided until we land on one that
// successfully completes the make-leave send-leave dance.
for _, serverName := range request.ServerNames {
// Try to perform a make_leave using the information supplied in the
// request.
respMakeLeave, err := r.federation.MakeLeave(
ctx,
serverName,
request.RoomID,
request.UserID,
)
if err != nil {
// TODO: Check if the user was not allowed to leave the room.
logrus.WithError(err).Warnf("r.federation.MakeLeave failed")
r.statistics.ForServer(serverName).Failure()
continue
}
// Set all the fields to be what they should be, this should be a no-op
// but it's possible that the remote server returned us something "odd"
respMakeLeave.LeaveEvent.Type = gomatrixserverlib.MRoomMember
respMakeLeave.LeaveEvent.Sender = request.UserID
respMakeLeave.LeaveEvent.StateKey = &request.UserID
respMakeLeave.LeaveEvent.RoomID = request.RoomID
respMakeLeave.LeaveEvent.Redacts = ""
if respMakeLeave.LeaveEvent.Content == nil {
content := map[string]interface{}{
"membership": "leave",
}
if err = respMakeLeave.LeaveEvent.SetContent(content); err != nil {
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetContent failed")
continue
}
}
if err = respMakeLeave.LeaveEvent.SetUnsigned(struct{}{}); err != nil {
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetUnsigned failed")
continue
}
// Work out if we support the room version that has been supplied in
// the make_leave response.
if _, err = respMakeLeave.RoomVersion.EventFormat(); err != nil {
return gomatrixserverlib.UnsupportedRoomVersionError{}
}
// Build the leave event.
event, err := respMakeLeave.LeaveEvent.Build(
time.Now(),
r.cfg.Matrix.ServerName,
r.cfg.Matrix.KeyID,
r.cfg.Matrix.PrivateKey,
respMakeLeave.RoomVersion,
)
if err != nil {
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.Build failed")
continue
}
// Try to perform a send_leave using the newly built event.
err = r.federation.SendLeave(
ctx,
serverName,
event,
)
if err != nil {
logrus.WithError(err).Warnf("r.federation.SendLeave failed")
r.statistics.ForServer(serverName).Failure()
continue
}
r.statistics.ForServer(serverName).Success()
return nil
}
// If we reach here then we didn't complete a leave for some reason.
return fmt.Errorf(
"failed to leave room %q through %d server(s)",
request.RoomID, len(request.ServerNames),
)
}
// PerformLeaveRequest implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformInvite(
ctx context.Context,
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) (err error) {
if request.Event.StateKey() == nil {
return errors.New("invite must be a state event")
}
_, destination, err := gomatrixserverlib.SplitID('@', *request.Event.StateKey())
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
logrus.WithFields(logrus.Fields{
"event_id": request.Event.EventID(),
"user_id": *request.Event.StateKey(),
"room_id": request.Event.RoomID(),
"room_version": request.RoomVersion,
"destination": destination,
}).Info("Sending invite")
inviteReq, err := gomatrixserverlib.NewInviteV2Request(request.Event, request.InviteRoomState)
if err != nil {
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
}
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq)
if err != nil {
return fmt.Errorf("r.federation.SendInviteV2: %w", err)
}
response.Event = inviteRes.Event.Headered(request.RoomVersion)
return nil
}
// PerformServersAlive implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformServersAlive(
ctx context.Context,
request *api.PerformServersAliveRequest,
response *api.PerformServersAliveResponse,
) (err error) {
for _, srv := range request.Servers {
_ = r.db.RemoveServerFromBlacklist(srv)
r.queues.RetryServer(srv)
}
return nil
}
// PerformServersAlive implements api.FederationInternalAPI
func (r *FederationInternalAPI) PerformBroadcastEDU(
ctx context.Context,
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) (err error) {
destinations, err := r.db.GetAllJoinedHosts(ctx)
if err != nil {
return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err)
}
if len(destinations) == 0 {
return nil
}
logrus.WithContext(ctx).Infof("Sending wake-up EDU to %d destination(s)", len(destinations))
edu := &gomatrixserverlib.EDU{
Type: "org.matrix.dendrite.wakeup",
Origin: string(r.cfg.Matrix.ServerName),
}
if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil {
return fmt.Errorf("r.queues.SendEDU: %w", err)
}
wakeReq := &api.PerformServersAliveRequest{
Servers: destinations,
}
wakeRes := &api.PerformServersAliveResponse{}
if err := r.PerformServersAlive(ctx, wakeReq, wakeRes); err != nil {
return fmt.Errorf("r.PerformServersAlive: %w", err)
}
return nil
}
func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
// sanity check we have a create event and it has a known room version
for _, ev := range authChain {
if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") {
// make sure the room version is known
content := ev.Content()
verBody := struct {
Version string `json:"room_version"`
}{}
err := json.Unmarshal(content, &verBody)
if err != nil {
return err
}
if verBody.Version == "" {
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-create
// The version of the room. Defaults to "1" if the key does not exist.
verBody.Version = "1"
}
knownVersions := gomatrixserverlib.RoomVersions()
if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok {
return fmt.Errorf("auth chain m.room.create event has an unknown room version: %s", verBody.Version)
}
return nil
}
}
return fmt.Errorf("auth chain response is missing m.room.create event")
}
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
// if auth events are not event references we know it must be v3+
// we have to do these shenanigans to satisfy sytest, specifically for:
// "Outbound federation rejects m.room.create events with an unknown room version"
hasEventRefs := true
authEvents, ok := joinEvent.AuthEvents.([]interface{})
if ok {
if len(authEvents) > 0 {
_, ok = authEvents[0].(string)
if ok {
// event refs are objects, not strings, so we know we must be dealing with a v3+ room.
hasEventRefs = false
}
}
}
if hasEventRefs {
return gomatrixserverlib.RoomVersionV1
}
return gomatrixserverlib.RoomVersionV4
}
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
func federatedAuthProvider(
ctx context.Context, federation *gomatrixserverlib.FederationClient,
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
) gomatrixserverlib.AuthChainProvider {
// A list of events that we have retried, if they were not included in
// the auth events supplied in the send_join.
retries := map[string][]*gomatrixserverlib.Event{}
// Define a function which we can pass to Check to retrieve missing
// auth events inline. This greatly increases our chances of not having
// to repeat the entire set of checks just for a missing event or two.
return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
returning := []*gomatrixserverlib.Event{}
// See if we have retry entries for each of the supplied event IDs.
for _, eventID := range eventIDs {
// If we've already satisfied a request for this event ID before then
// just append the results. We won't retry the request.
if retry, ok := retries[eventID]; ok {
if retry == nil {
return nil, fmt.Errorf("missingAuth: not retrying failed event ID %q", eventID)
}
returning = append(returning, retry...)
continue
}
// Make a note of the fact that we tried to do something with this
// event ID, even if we don't succeed.
retries[eventID] = nil
// Try to retrieve the event from the server that sent us the send
// join response.
tx, txerr := federation.GetEvent(ctx, server, eventID)
if txerr != nil {
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
}
// For each event returned, add it to the set of return events. We
// also will populate the retries, in case someone asks for this
// event ID again.
for _, pdu := range tx.PDUs {
// Try to parse the event.
ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
if everr != nil {
return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
}
// Check the signatures of the event.
if err := ev.VerifyEventSignatures(ctx, keyRing); err != nil {
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
}
// If the event is OK then add it to the results and the retry map.
returning = append(returning, ev)
retries[ev.EventID()] = append(retries[ev.EventID()], ev)
}
}
return returning, nil
}
}

View file

@ -0,0 +1,97 @@
package internal
import (
"context"
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
)
// QueryJoinedHostServerNamesInRoom implements api.FederationInternalAPI
func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) (err error) {
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID})
if err != nil {
return
}
response.ServerNames = joinedHosts
return
}
func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
defer cancel()
ires, err := a.doRequest(serverName, func() (interface{}, error) {
return a.federation.GetServerKeys(ctx, serverName)
})
if err != nil {
return nil, err
}
sks := ires.(gomatrixserverlib.ServerKeys)
return &sks, nil
}
func (a *FederationInternalAPI) fetchServerKeysFromCache(
ctx context.Context, req *api.QueryServerKeysRequest,
) ([]gomatrixserverlib.ServerKeys, error) {
var results []gomatrixserverlib.ServerKeys
for keyID, criteria := range req.KeyIDToCriteria {
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID})
if len(serverKeysResponses) == 0 {
return nil, fmt.Errorf("failed to find server key response for key ID %s", keyID)
}
// we should only get 1 result as we only gave 1 key ID
sk := serverKeysResponses[0]
util.GetLogger(ctx).Infof("fetchServerKeysFromCache: minvalid:%v keys: %+v", criteria.MinimumValidUntilTS, sk)
if criteria.MinimumValidUntilTS != 0 {
// check if it's still valid. if they have the same value that's also valid
if sk.ValidUntilTS < criteria.MinimumValidUntilTS {
return nil, fmt.Errorf(
"found server response for key ID %s but it is no longer valid, min: %v valid_until: %v",
keyID, criteria.MinimumValidUntilTS, sk.ValidUntilTS,
)
}
}
results = append(results, sk)
}
return results, nil
}
func (a *FederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error {
// attempt to satisfy the entire request from the cache first
results, err := a.fetchServerKeysFromCache(ctx, req)
if err == nil {
// satisfied entirely from cache, return it
res.ServerKeys = results
return nil
}
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
if err != nil {
// try to load as much as we can from the cache in a best effort basis
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to ask server for keys, returning best effort keys")
serverKeysResponses, dbErr := a.db.GetNotaryKeys(ctx, req.ServerName, req.KeyIDs())
if dbErr != nil {
return fmt.Errorf("notary: server returned %s, and db returned %s", err, dbErr)
}
res.ServerKeys = serverKeysResponses
return nil
}
// cache it!
if err = a.db.UpdateNotaryKeys(context.Background(), req.ServerName, *serverKeys); err != nil {
// non-fatal, still return the response
util.GetLogger(ctx).WithError(err).Warn("failed to UpdateNotaryKeys")
}
res.ServerKeys = []gomatrixserverlib.ServerKeys{*serverKeys}
return nil
}

View file

@ -0,0 +1,575 @@
package inthttp
import (
"context"
"errors"
"net/http"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/opentracing/opentracing-go"
)
// HTTP paths for the internal HTTP API
const (
FederationAPIQueryJoinedHostServerNamesInRoomPath = "/federationapi/queryJoinedHostServerNamesInRoom"
FederationAPIQueryServerKeysPath = "/federationapi/queryServerKeys"
FederationAPIPerformDirectoryLookupRequestPath = "/federationapi/performDirectoryLookup"
FederationAPIPerformJoinRequestPath = "/federationapi/performJoinRequest"
FederationAPIPerformLeaveRequestPath = "/federationapi/performLeaveRequest"
FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest"
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
FederationAPIPerformServersAlivePath = "/federationapi/performServersAlive"
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
FederationAPIBackfillPath = "/federationapi/client/backfill"
FederationAPILookupStatePath = "/federationapi/client/lookupState"
FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs"
FederationAPIGetEventPath = "/federationapi/client/getEvent"
FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys"
FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships"
FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary"
FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey"
FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey"
)
// NewFederationAPIClient creates a FederationInternalAPI implemented by talking to a HTTP POST API.
// If httpClient is nil an error is returned
func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.FederationInternalAPI, error) {
if httpClient == nil {
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>")
}
return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil
}
type httpFederationInternalAPI struct {
federationAPIURL string
httpClient *http.Client
cache caching.ServerKeyCache
}
// Handle an instruction to make_leave & send_leave with a remote server.
func (h *httpFederationInternalAPI) PerformLeave(
ctx context.Context,
request *api.PerformLeaveRequest,
response *api.PerformLeaveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// Handle sending an invite to a remote server.
func (h *httpFederationInternalAPI) PerformInvite(
ctx context.Context,
request *api.PerformInviteRequest,
response *api.PerformInviteResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// Handle starting a peek on a remote server.
func (h *httpFederationInternalAPI) PerformOutboundPeek(
ctx context.Context,
request *api.PerformOutboundPeekRequest,
response *api.PerformOutboundPeekResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpFederationInternalAPI) PerformServersAlive(
ctx context.Context,
request *api.PerformServersAliveRequest,
response *api.PerformServersAliveResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformServersAlive")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformServersAlivePath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
ctx context.Context,
request *api.QueryJoinedHostServerNamesInRoomRequest,
response *api.QueryJoinedHostServerNamesInRoomResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// Handle an instruction to make_join & send_join with a remote server.
func (h *httpFederationInternalAPI) PerformJoin(
ctx context.Context,
request *api.PerformJoinRequest,
response *api.PerformJoinResponse,
) {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
if err != nil {
response.LastError = &gomatrix.HTTPError{
Message: err.Error(),
Code: 0,
WrappedError: err,
}
}
}
// Handle an instruction to make_join & send_join with a remote server.
func (h *httpFederationInternalAPI) PerformDirectoryLookup(
ctx context.Context,
request *api.PerformDirectoryLookupRequest,
response *api.PerformDirectoryLookupResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
func (h *httpFederationInternalAPI) PerformBroadcastEDU(
ctx context.Context,
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
type getUserDevices struct {
S gomatrixserverlib.ServerName
UserID string
Res *gomatrixserverlib.RespUserDevices
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetUserDevices(
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
) (gomatrixserverlib.RespUserDevices, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices")
defer span.Finish()
var result gomatrixserverlib.RespUserDevices
request := getUserDevices{
S: s,
UserID: userID,
}
var response getUserDevices
apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
}
type claimKeys struct {
S gomatrixserverlib.ServerName
OneTimeKeys map[string]map[string]string
Res *gomatrixserverlib.RespClaimKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) ClaimKeys(
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
) (gomatrixserverlib.RespClaimKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys")
defer span.Finish()
var result gomatrixserverlib.RespClaimKeys
request := claimKeys{
S: s,
OneTimeKeys: oneTimeKeys,
}
var response claimKeys
apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
}
type queryKeys struct {
S gomatrixserverlib.ServerName
Keys map[string][]string
Res *gomatrixserverlib.RespQueryKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) QueryKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
) (gomatrixserverlib.RespQueryKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
defer span.Finish()
var result gomatrixserverlib.RespQueryKeys
request := queryKeys{
S: s,
Keys: keys,
}
var response queryKeys
apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return result, err
}
if response.Err != nil {
return result, response.Err
}
return *response.Res, nil
}
type backfill struct {
S gomatrixserverlib.ServerName
RoomID string
Limit int
EventIDs []string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) Backfill(
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
defer span.Finish()
request := backfill{
S: s,
RoomID: roomID,
Limit: limit,
EventIDs: eventIDs,
}
var response backfill
apiURL := h.federationAPIURL + FederationAPIBackfillPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
}
type lookupState struct {
S gomatrixserverlib.ServerName
RoomID string
EventID string
RoomVersion gomatrixserverlib.RoomVersion
Res *gomatrixserverlib.RespState
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupState(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
) (gomatrixserverlib.RespState, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
defer span.Finish()
request := lookupState{
S: s,
RoomID: roomID,
EventID: eventID,
RoomVersion: roomVersion,
}
var response lookupState
apiURL := h.federationAPIURL + FederationAPILookupStatePath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespState{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespState{}, response.Err
}
return *response.Res, nil
}
type lookupStateIDs struct {
S gomatrixserverlib.ServerName
RoomID string
EventID string
Res *gomatrixserverlib.RespStateIDs
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupStateIDs(
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
) (gomatrixserverlib.RespStateIDs, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
defer span.Finish()
request := lookupStateIDs{
S: s,
RoomID: roomID,
EventID: eventID,
}
var response lookupStateIDs
apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.RespStateIDs{}, err
}
if response.Err != nil {
return gomatrixserverlib.RespStateIDs{}, response.Err
}
return *response.Res, nil
}
type getEvent struct {
S gomatrixserverlib.ServerName
EventID string
Res *gomatrixserverlib.Transaction
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) GetEvent(
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
) (gomatrixserverlib.Transaction, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
defer span.Finish()
request := getEvent{
S: s,
EventID: eventID,
}
var response getEvent
apiURL := h.federationAPIURL + FederationAPIGetEventPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return gomatrixserverlib.Transaction{}, err
}
if response.Err != nil {
return gomatrixserverlib.Transaction{}, response.Err
}
return *response.Res, nil
}
func (h *httpFederationInternalAPI) QueryServerKeys(
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
}
type lookupServerKeys struct {
S gomatrixserverlib.ServerName
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
ServerKeys []gomatrixserverlib.ServerKeys
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) LookupServerKeys(
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) ([]gomatrixserverlib.ServerKeys, error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
defer span.Finish()
request := lookupServerKeys{
S: s,
KeyRequests: keyRequests,
}
var response lookupServerKeys
apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return []gomatrixserverlib.ServerKeys{}, err
}
if response.Err != nil {
return []gomatrixserverlib.ServerKeys{}, response.Err
}
return response.ServerKeys, nil
}
type eventRelationships struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
RoomVer gomatrixserverlib.RoomVersion
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2836EventRelationships(
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
roomVersion gomatrixserverlib.RoomVersion,
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
defer span.Finish()
request := eventRelationships{
S: s,
Req: r,
RoomVer: roomVersion,
}
var response eventRelationships
apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
}
type spacesReq struct {
S gomatrixserverlib.ServerName
Req gomatrixserverlib.MSC2946SpacesRequest
RoomID string
Res gomatrixserverlib.MSC2946SpacesResponse
Err *api.FederationClientError
}
func (h *httpFederationInternalAPI) MSC2946Spaces(
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
defer span.Finish()
request := spacesReq{
S: dst,
Req: r,
RoomID: roomID,
}
var response spacesReq
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
if err != nil {
return res, err
}
if response.Err != nil {
return res, response.Err
}
return response.Res, nil
}
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
// This is a bit of a cheat - we tell gomatrixserverlib that this API is
// both the key database and the key fetcher. While this does have the
// rather unfortunate effect of preventing gomatrixserverlib from handling
// key fetchers directly, we can at least reimplement this behaviour on
// the other end of the API.
return &gomatrixserverlib.KeyRing{
KeyDatabase: s,
KeyFetchers: []gomatrixserverlib.KeyFetcher{},
}
}
func (s *httpFederationInternalAPI) FetcherName() string {
return "httpServerKeyInternalAPI"
}
func (s *httpFederationInternalAPI) StoreKeys(
_ context.Context,
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
// Run in a background context - we don't want to stop this work just
// because the caller gives up waiting.
ctx := context.Background()
request := api.InputPublicKeysRequest{
Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult),
}
response := api.InputPublicKeysResponse{}
for req, res := range results {
request.Keys[req] = res
s.cache.StoreServerKey(req, res)
}
return s.InputPublicKeys(ctx, &request, &response)
}
func (s *httpFederationInternalAPI) FetchKeys(
_ context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
// Run in a background context - we don't want to stop this work just
// because the caller gives up waiting.
ctx := context.Background()
result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
request := api.QueryPublicKeysRequest{
Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp),
}
response := api.QueryPublicKeysResponse{
Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult),
}
for req, ts := range requests {
if res, ok := s.cache.GetServerKey(req, ts); ok {
result[req] = res
continue
}
request.Requests[req] = ts
}
err := s.QueryPublicKeys(ctx, &request, &response)
if err != nil {
return nil, err
}
for req, res := range response.Results {
result[req] = res
s.cache.StoreServerKey(req, res)
}
return result, nil
}
func (h *httpFederationInternalAPI) InputPublicKeys(
ctx context.Context,
request *api.InputPublicKeysRequest,
response *api.InputPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}
func (h *httpFederationInternalAPI) QueryPublicKeys(
ctx context.Context,
request *api.QueryPublicKeysRequest,
response *api.QueryPublicKeysResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey")
defer span.Finish()
apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View file

@ -0,0 +1,374 @@
package inthttp
import (
"encoding/json"
"net/http"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/internal/httputil"
"github.com/matrix-org/util"
)
// AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux.
// nolint:gocyclo
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
internalAPIMux.Handle(
FederationAPIQueryJoinedHostServerNamesInRoomPath,
httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryJoinedHostServerNamesInRoomRequest
var response api.QueryJoinedHostServerNamesInRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.ErrorResponse(err)
}
if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformJoinRequestPath,
httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformJoinRequest
var response api.PerformJoinResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
intAPI.PerformJoin(req.Context(), &request, &response)
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformLeaveRequestPath,
httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformLeaveRequest
var response api.PerformLeaveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformInviteRequestPath,
httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformInviteRequest
var response api.PerformInviteResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformDirectoryLookupRequestPath,
httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformDirectoryLookupRequest
var response api.PerformDirectoryLookupResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformServersAlivePath,
httputil.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse {
var request api.PerformServersAliveRequest
var response api.PerformServersAliveResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformServersAlive(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIPerformBroadcastEDUPath,
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
var request api.PerformBroadcastEDURequest
var response api.PerformBroadcastEDUResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPIGetUserDevicesPath,
httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse {
var request getUserDevices
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIClaimKeysPath,
httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse {
var request claimKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIQueryKeysPath,
httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse {
var request queryKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIBackfillPath,
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
var request backfill
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPILookupStatePath,
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
var request lookupState
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPILookupStateIDsPath,
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
var request lookupStateIDs
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIGetEventPath,
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
var request getEvent
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = &res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIQueryServerKeysPath,
httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse {
var request api.QueryServerKeysRequest
var response api.QueryServerKeysResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(
FederationAPILookupServerKeysPath,
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
var request lookupServerKeys
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.ServerKeys = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPIEventRelationshipsPath,
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
var request eventRelationships
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(
FederationAPISpacesSummaryPath,
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
var request spacesReq
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
if err != nil {
ferr, ok := err.(*api.FederationClientError)
if ok {
request.Err = ferr
} else {
request.Err = &api.FederationClientError{
Err: err.Error(),
}
}
}
request.Res = res
return util.JSONResponse{Code: http.StatusOK, JSON: request}
}),
)
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.QueryPublicKeysRequest{}
response := api.QueryPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
keys, err := intAPI.FetchKeys(req.Context(), request.Requests)
if err != nil {
return util.ErrorResponse(err)
}
response.Results = keys
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse {
request := api.InputPublicKeysRequest{}
response := api.InputPublicKeysResponse{}
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.StoreKeys(req.Context(), request.Keys); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
}

View file

@ -0,0 +1,448 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package queue
import (
"context"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrix"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
)
const (
maxPDUsPerTransaction = 50
maxEDUsPerTransaction = 50
maxPDUsInMemory = 128
maxEDUsInMemory = 128
queueIdleTimeout = time.Second * 30
)
// destinationQueue is a queue of events for a single destination.
// It is responsible for sending the events to the destination and
// ensures that only one request is in flight to a given destination
// at a time.
type destinationQueue struct {
queues *OutgoingQueues
db storage.Database
process *process.ProcessContext
signing *SigningInfo
rsAPI api.RoomserverInternalAPI
client *gomatrixserverlib.FederationClient // federation client
origin gomatrixserverlib.ServerName // origin of requests
destination gomatrixserverlib.ServerName // destination of requests
running atomic.Bool // is the queue worker running?
backingOff atomic.Bool // true if we're backing off
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
statistics *statistics.ServerStatistics // statistics about this remote server
transactionIDMutex sync.Mutex // protects transactionID
transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
notify chan struct{} // interrupts idle wait pending PDUs/EDUs
pendingPDUs []*queuedPDU // PDUs waiting to be sent
pendingEDUs []*queuedEDU // EDUs waiting to be sent
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
interruptBackoff chan bool // interrupts backoff
}
// Send event adds the event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociatePDUWithDestination(
context.TODO(),
"", // TODO: remove this, as we don't need to persist the transaction ID
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingPDUs) < maxPDUsInMemory {
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
pdu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
select {
case oq.notify <- struct{}{}:
default:
}
}
}
// sendEDU adds the EDU event to the pending queue for the destination.
// If the queue is empty then it starts a background goroutine to
// start sending events to that destination.
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
if event == nil {
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
return
}
// Create a database entry that associates the given PDU NID with
// this destination queue. We'll then be able to retrieve the PDU
// later.
if err := oq.db.AssociateEDUWithDestination(
context.TODO(),
oq.destination, // the destination server name
receipt, // NIDs from federationapi_queue_json table
); err != nil {
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
return
}
// Check if the destination is blacklisted. If it isn't then wake
// up the queue.
if !oq.statistics.Blacklisted() {
// If there's room in memory to hold the event then add it to the
// list.
oq.pendingMutex.Lock()
if len(oq.pendingEDUs) < maxEDUsInMemory {
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
edu: event,
receipt: receipt,
})
} else {
oq.overflowed.Store(true)
}
oq.pendingMutex.Unlock()
// Wake up the queue if it's asleep.
oq.wakeQueueIfNeeded()
select {
case oq.notify <- struct{}{}:
default:
}
}
}
// wakeQueueIfNeeded will wake up the destination queue if it is
// not already running. If it is running but it is backing off
// then we will interrupt the backoff, causing any federation
// requests to retry.
func (oq *destinationQueue) wakeQueueIfNeeded() {
// If we are backing off then interrupt the backoff.
if oq.backingOff.CAS(true, false) {
oq.interruptBackoff <- true
}
// If we aren't running then wake up the queue.
if !oq.running.Load() {
// Start the queue.
go oq.backgroundSend()
}
}
// getPendingFromDatabase will look at the database and see if
// there are any persisted events that haven't been sent to this
// destination yet. If so, they will be queued up.
func (oq *destinationQueue) getPendingFromDatabase() {
// Check to see if there's anything to do for this server
// in the database.
retrieved := false
ctx := context.Background()
oq.pendingMutex.Lock()
defer oq.pendingMutex.Unlock()
// Take a note of all of the PDUs and EDUs that we already
// have cached. We will index them based on the receipt,
// which ultimately just contains the index of the PDU/EDU
// in the database.
gotPDUs := map[string]struct{}{}
gotEDUs := map[string]struct{}{}
for _, pdu := range oq.pendingPDUs {
gotPDUs[pdu.receipt.String()] = struct{}{}
}
for _, edu := range oq.pendingEDUs {
gotEDUs[edu.receipt.String()] = struct{}{}
}
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
// We have room in memory for some PDUs - let's request no more than that.
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
for receipt, pdu := range pdus {
if _, ok := gotPDUs[receipt.String()]; ok {
continue
}
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
}
}
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
// We have room in memory for some EDUs - let's request no more than that.
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
for receipt, edu := range edus {
if _, ok := gotEDUs[receipt.String()]; ok {
continue
}
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
retrieved = true
}
} else {
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
}
}
// If we've retrieved all of the events from the database with room to spare
// in memory then we'll no longer consider this queue to be overflowed.
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
oq.overflowed.Store(false)
}
// If we've retrieved some events then notify the destination queue goroutine.
if retrieved {
select {
case oq.notify <- struct{}{}:
default:
}
}
}
// backgroundSend is the worker goroutine for sending events.
func (oq *destinationQueue) backgroundSend() {
// Check if a worker is already running, and if it isn't, then
// mark it as started.
if !oq.running.CAS(false, true) {
return
}
destinationQueueRunning.Inc()
defer destinationQueueRunning.Dec()
defer oq.queues.clearQueue(oq)
defer oq.running.Store(false)
// Mark the queue as overflowed, so we will consult the database
// to see if there's anything new to send.
oq.overflowed.Store(true)
for {
// If we are overflowing memory and have sent things out to the
// database then we can look up what those things are.
if oq.overflowed.Load() {
oq.getPendingFromDatabase()
}
// If we have nothing to do then wait either for incoming events, or
// until we hit an idle timeout.
select {
case <-oq.notify:
// There's work to do, either because getPendingFromDatabase
// told us there is, or because a new event has come in via
// sendEvent/sendEDU.
case <-time.After(queueIdleTimeout):
// The worker is idle so stop the goroutine. It'll get
// restarted automatically the next time we have an event to
// send.
return
}
// If we are backing off this server then wait for the
// backoff duration to complete first, or until explicitly
// told to retry.
until, blacklisted := oq.statistics.BackoffInfo()
if blacklisted {
// It's been suggested that we should give up because the backoff
// has exceeded a maximum allowable value. Clean up the in-memory
// buffers at this point. The PDU clean-up is already on a defer.
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = nil
oq.pendingEDUs = nil
oq.pendingMutex.Unlock()
return
}
if until != nil && until.After(time.Now()) {
// We haven't backed off yet, so wait for the suggested amount of
// time.
duration := time.Until(*until)
logrus.Warnf("Backing off %q for %s", oq.destination, duration)
oq.backingOff.Store(true)
destinationQueueBackingOff.Inc()
select {
case <-time.After(duration):
case <-oq.interruptBackoff:
}
destinationQueueBackingOff.Dec()
oq.backingOff.Store(false)
}
// Work out which PDUs/EDUs to include in the next transaction.
oq.pendingMutex.RLock()
pduCount := len(oq.pendingPDUs)
eduCount := len(oq.pendingEDUs)
if pduCount > maxPDUsPerTransaction {
pduCount = maxPDUsPerTransaction
}
if eduCount > maxEDUsPerTransaction {
eduCount = maxEDUsPerTransaction
}
toSendPDUs := oq.pendingPDUs[:pduCount]
toSendEDUs := oq.pendingEDUs[:eduCount]
oq.pendingMutex.RUnlock()
// If we have pending PDUs or EDUs then construct a transaction.
// Try sending the next transaction and see what happens.
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
if terr != nil {
// We failed to send the transaction. Mark it as a failure.
oq.statistics.Failure()
} else if transaction {
// If we successfully sent the transaction then clear out
// the pending events and EDUs, and wipe our transaction ID.
oq.statistics.Success()
oq.pendingMutex.Lock()
for i := range oq.pendingPDUs[:pc] {
oq.pendingPDUs[i] = nil
}
for i := range oq.pendingEDUs[:ec] {
oq.pendingEDUs[i] = nil
}
oq.pendingPDUs = oq.pendingPDUs[pc:]
oq.pendingEDUs = oq.pendingEDUs[ec:]
oq.pendingMutex.Unlock()
}
}
}
// nextTransaction creates a new transaction from the pending event
// queue and sends it. Returns true if a transaction was sent or
// false otherwise.
func (oq *destinationQueue) nextTransaction(
pdus []*queuedPDU,
edus []*queuedEDU,
) (bool, int, int, error) {
// If there's no projected transaction ID then generate one. If
// the transaction succeeds then we'll set it back to "" so that
// we generate a new one next time. If it fails, we'll preserve
// it so that we retry with the same transaction ID.
oq.transactionIDMutex.Lock()
if oq.transactionID == "" {
now := gomatrixserverlib.AsTimestamp(time.Now())
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
}
oq.transactionIDMutex.Unlock()
// Create the transaction.
t := gomatrixserverlib.Transaction{
PDUs: []json.RawMessage{},
EDUs: []gomatrixserverlib.EDU{},
}
t.Origin = oq.origin
t.Destination = oq.destination
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
t.TransactionID = oq.transactionID
// If we didn't get anything from the database and there are no
// pending EDUs then there's nothing to do - stop here.
if len(pdus) == 0 && len(edus) == 0 {
return false, 0, 0, nil
}
var pduReceipts []*shared.Receipt
var eduReceipts []*shared.Receipt
// Go through PDUs that we retrieved from the database, if any,
// and add them into the transaction.
for _, pdu := range pdus {
if pdu == nil || pdu.pdu == nil {
continue
}
// Append the JSON of the event, since this is a json.RawMessage type in the
// gomatrixserverlib.Transaction struct
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
pduReceipts = append(pduReceipts, pdu.receipt)
}
// Do the same for pending EDUS in the queue.
for _, edu := range edus {
if edu == nil || edu.edu == nil {
continue
}
t.EDUs = append(t.EDUs, *edu.edu)
eduReceipts = append(eduReceipts, edu.receipt)
}
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
// Try to send the transaction to the destination server.
// TODO: we should check for 500-ish fails vs 400-ish here,
// since we shouldn't queue things indefinitely in response
// to a 400-ish error
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
defer cancel()
_, err := oq.client.SendTransaction(ctx, t)
switch err.(type) {
case nil:
// Clean up the transaction in the database.
if pduReceipts != nil {
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
}
}
if eduReceipts != nil {
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
}
}
// Reset the transaction ID.
oq.transactionIDMutex.Lock()
oq.transactionID = ""
oq.transactionIDMutex.Unlock()
return true, len(t.PDUs), len(t.EDUs), nil
case gomatrix.HTTPError:
// Report that we failed to send the transaction and we
// will retry again, subject to backoff.
return false, 0, 0, err
default:
logrus.WithFields(logrus.Fields{
"destination": oq.destination,
logrus.ErrorKey: err,
}).Debugf("Failed to send transaction %q", t.TransactionID)
return false, 0, 0, err
}
}

View file

@ -0,0 +1,339 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package queue
import (
"context"
"crypto/ed25519"
"encoding/json"
"fmt"
"sync"
"time"
"github.com/matrix-org/dendrite/federationapi/statistics"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/process"
"github.com/matrix-org/gomatrixserverlib"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// OutgoingQueues is a collection of queues for sending transactions to other
// matrix servers
type OutgoingQueues struct {
db storage.Database
process *process.ProcessContext
disabled bool
rsAPI api.RoomserverInternalAPI
origin gomatrixserverlib.ServerName
client *gomatrixserverlib.FederationClient
statistics *statistics.Statistics
signing *SigningInfo
queuesMutex sync.Mutex // protects the below
queues map[gomatrixserverlib.ServerName]*destinationQueue
}
func init() {
prometheus.MustRegister(
destinationQueueTotal, destinationQueueRunning,
destinationQueueBackingOff,
)
}
var destinationQueueTotal = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "destination_queues_total",
},
)
var destinationQueueRunning = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "destination_queues_running",
},
)
var destinationQueueBackingOff = prometheus.NewGauge(
prometheus.GaugeOpts{
Namespace: "dendrite",
Subsystem: "federationapi",
Name: "destination_queues_backing_off",
},
)
// NewOutgoingQueues makes a new OutgoingQueues
func NewOutgoingQueues(
db storage.Database,
process *process.ProcessContext,
disabled bool,
origin gomatrixserverlib.ServerName,
client *gomatrixserverlib.FederationClient,
rsAPI api.RoomserverInternalAPI,
statistics *statistics.Statistics,
signing *SigningInfo,
) *OutgoingQueues {
queues := &OutgoingQueues{
disabled: disabled,
process: process,
db: db,
rsAPI: rsAPI,
origin: origin,
client: client,
statistics: statistics,
signing: signing,
queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
}
// Look up which servers we have pending items for and then rehydrate those queues.
if !disabled {
time.AfterFunc(time.Second*5, func() {
serverNames := map[gomatrixserverlib.ServerName]struct{}{}
if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil {
for _, serverName := range names {
serverNames[serverName] = struct{}{}
}
} else {
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
}
if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil {
for _, serverName := range names {
serverNames[serverName] = struct{}{}
}
} else {
log.WithError(err).Error("Failed to get EDU server names for destination queue hydration")
}
for serverName := range serverNames {
if queue := queues.getQueue(serverName); queue != nil {
queue.wakeQueueIfNeeded()
}
}
})
}
return queues
}
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
// around together
type SigningInfo struct {
ServerName gomatrixserverlib.ServerName
KeyID gomatrixserverlib.KeyID
PrivateKey ed25519.PrivateKey
}
type queuedPDU struct {
receipt *shared.Receipt
pdu *gomatrixserverlib.HeaderedEvent
}
type queuedEDU struct {
receipt *shared.Receipt
edu *gomatrixserverlib.EDU
}
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
if oqs.statistics.ForServer(destination).Blacklisted() {
return nil
}
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
oq, ok := oqs.queues[destination]
if !ok || oq != nil {
destinationQueueTotal.Inc()
oq = &destinationQueue{
queues: oqs,
db: oqs.db,
process: oqs.process,
rsAPI: oqs.rsAPI,
origin: oqs.origin,
destination: destination,
client: oqs.client,
statistics: oqs.statistics.ForServer(destination),
notify: make(chan struct{}, 1),
interruptBackoff: make(chan bool),
signing: oqs.signing,
}
oqs.queues[destination] = oq
}
return oq
}
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
oqs.queuesMutex.Lock()
defer oqs.queuesMutex.Unlock()
delete(oqs.queues, oq.destination)
destinationQueueTotal.Dec()
}
type ErrorFederationDisabled struct {
Message string
}
func (e *ErrorFederationDisabled) Error() string {
return e.Message
}
// SendEvent sends an event to the destinations
func (oqs *OutgoingQueues) SendEvent(
ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName,
destinations []gomatrixserverlib.ServerName,
) error {
if oqs.disabled {
return &ErrorFederationDisabled{
Message: "Federation disabled",
}
}
if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q",
origin, oqs.origin,
)
}
// Deduplicate destinations and remove the origin from the list of
// destinations just to be sure.
destmap := map[gomatrixserverlib.ServerName]struct{}{}
for _, d := range destinations {
destmap[d] = struct{}{}
}
delete(destmap, oqs.origin)
// Check if any of the destinations are prohibited by server ACLs.
for destination := range destmap {
if api.IsServerBannedFromRoom(
context.TODO(),
oqs.rsAPI,
ev.RoomID(),
destination,
) {
delete(destmap, destination)
}
}
// If there are no remaining destinations then give up.
if len(destmap) == 0 {
return nil
}
log.WithFields(log.Fields{
"destinations": len(destmap), "event": ev.EventID(),
}).Infof("Sending event")
headeredJSON, err := json.Marshal(ev)
if err != nil {
return fmt.Errorf("json.Marshal: %w", err)
}
nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON))
if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
queue.sendEvent(ev, nid)
}
}
return nil
}
// SendEDU sends an EDU event to the destinations.
func (oqs *OutgoingQueues) SendEDU(
e *gomatrixserverlib.EDU, origin gomatrixserverlib.ServerName,
destinations []gomatrixserverlib.ServerName,
) error {
if oqs.disabled {
return &ErrorFederationDisabled{
Message: "Federation disabled",
}
}
if origin != oqs.origin {
// TODO: Support virtual hosting; gh issue #577.
return fmt.Errorf(
"sendevent: unexpected server to send as: got %q expected %q",
origin, oqs.origin,
)
}
// Deduplicate destinations and remove the origin from the list of
// destinations just to be sure.
destmap := map[gomatrixserverlib.ServerName]struct{}{}
for _, d := range destinations {
destmap[d] = struct{}{}
}
delete(destmap, oqs.origin)
// There is absolutely no guarantee that the EDU will have a room_id
// field, as it is not required by the spec. However, if it *does*
// (e.g. typing notifications) then we should try to make sure we don't
// bother sending them to servers that are prohibited by the server
// ACLs.
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
for destination := range destmap {
if api.IsServerBannedFromRoom(
context.TODO(),
oqs.rsAPI,
result.Str,
destination,
) {
delete(destmap, destination)
}
}
}
// If there are no remaining destinations then give up.
if len(destmap) == 0 {
return nil
}
log.WithFields(log.Fields{
"destinations": len(destmap), "edu_type": e.Type,
}).Info("Sending EDU event")
ephemeralJSON, err := json.Marshal(e)
if err != nil {
return fmt.Errorf("json.Marshal: %w", err)
}
nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON))
if err != nil {
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
}
for destination := range destmap {
if queue := oqs.getQueue(destination); queue != nil {
queue.sendEDU(e, nid)
}
}
return nil
}
// RetryServer attempts to resend events to the given server if we had given up.
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
if oqs.disabled {
return
}
if queue := oqs.getQueue(srv); queue != nil {
queue.wakeQueueIfNeeded()
}
}

View file

@ -21,7 +21,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/httputil"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
"github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrixserverlib"
@ -179,7 +179,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
func NotaryKeys(
httpReq *http.Request, cfg *config.FederationAPI,
fsAPI federationSenderAPI.FederationSenderInternalAPI,
fsAPI federationAPI.FederationInternalAPI,
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
) util.JSONResponse {
if req == nil {
@ -203,8 +203,8 @@ func NotaryKeys(
return util.ErrorResponse(err)
}
} else {
var resp federationSenderAPI.QueryServerKeysResponse
err := fsAPI.QueryServerKeys(httpReq.Context(), &federationSenderAPI.QueryServerKeysRequest{
var resp federationAPI.QueryServerKeysResponse
err := fsAPI.QueryServerKeys(httpReq.Context(), &federationAPI.QueryServerKeysRequest{
ServerName: serverName,
KeyIDToCriteria: kidToCriteria,
}, &resp)

View file

@ -33,7 +33,7 @@ func Peek(
roomID, peekID string,
remoteVersions []gomatrixserverlib.RoomVersion,
) util.JSONResponse {
// TODO: check if we're just refreshing an existing peek by querying the federationsender
// TODO: check if we're just refreshing an existing peek by querying the federationapi
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
verRes := api.QueryRoomVersionForRoomResponse{}

View file

@ -19,7 +19,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/jsonerror"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/gomatrix"
@ -33,7 +33,7 @@ func RoomAliasToID(
federation *gomatrixserverlib.FederationClient,
cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
senderAPI federationSenderAPI.FederationSenderInternalAPI,
senderAPI federationAPI.FederationInternalAPI,
) util.JSONResponse {
roomAlias := httpReq.FormValue("room_alias")
if roomAlias == "" {
@ -64,8 +64,8 @@ func RoomAliasToID(
}
if queryRes.RoomID != "" {
serverQueryReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID}
var serverQueryRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse
serverQueryReq := federationAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID}
var serverQueryRes federationAPI.QueryJoinedHostServerNamesInRoomResponse
if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil {
util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed")
return jsonerror.InternalServerError()

View file

@ -21,7 +21,6 @@ import (
"github.com/matrix-org/dendrite/clientapi/jsonerror"
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/httputil"
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
@ -46,7 +45,7 @@ func Setup(
cfg *config.FederationAPI,
rsAPI roomserverAPI.RoomserverInternalAPI,
eduAPI eduserverAPI.EDUServerInputAPI,
fsAPI federationSenderAPI.FederationSenderInternalAPI,
fsAPI federationAPI.FederationInternalAPI,
keys gomatrixserverlib.JSONVerifier,
federation *gomatrixserverlib.FederationClient,
userAPI userapi.UserInternalAPI,

View file

@ -0,0 +1,180 @@
package statistics
import (
"math"
"sync"
"time"
"github.com/matrix-org/dendrite/federationapi/storage"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"go.uber.org/atomic"
)
// Statistics contains information about all of the remote federated
// hosts that we have interacted with. It is basically a threadsafe
// wrapper.
type Statistics struct {
DB storage.Database
servers map[gomatrixserverlib.ServerName]*ServerStatistics
mutex sync.RWMutex
// How many times should we tolerate consecutive failures before we
// just blacklist the host altogether? The backoff is exponential,
// so the max time here to attempt is 2**failures seconds.
FailuresUntilBlacklist uint32
}
// ForServer returns server statistics for the given server name. If it
// does not exist, it will create empty statistics and return those.
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
// If the map hasn't been initialised yet then do that.
if s.servers == nil {
s.mutex.Lock()
s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics)
s.mutex.Unlock()
}
// Look up if we have statistics for this server already.
s.mutex.RLock()
server, found := s.servers[serverName]
s.mutex.RUnlock()
// If we don't, then make one.
if !found {
s.mutex.Lock()
server = &ServerStatistics{
statistics: s,
serverName: serverName,
interrupt: make(chan struct{}),
}
s.servers[serverName] = server
s.mutex.Unlock()
blacklisted, err := s.DB.IsServerBlacklisted(serverName)
if err != nil {
logrus.WithError(err).Errorf("Failed to get blacklist entry %q", serverName)
} else {
server.blacklisted.Store(blacklisted)
}
}
return server
}
// ServerStatistics contains information about our interactions with a
// remote federated host, e.g. how many times we were successful, how
// many times we failed etc. It also manages the backoff time and black-
// listing a remote host if it remains uncooperative.
type ServerStatistics struct {
statistics *Statistics //
serverName gomatrixserverlib.ServerName //
blacklisted atomic.Bool // is the node blacklisted
backoffStarted atomic.Bool // is the backoff started
backoffUntil atomic.Value // time.Time until this backoff interval ends
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
interrupt chan struct{} // interrupts the backoff goroutine
successCounter atomic.Uint32 // how many times have we succeeded?
}
// duration returns how long the next backoff interval should be.
func (s *ServerStatistics) duration(count uint32) time.Duration {
return time.Second * time.Duration(math.Exp2(float64(count)))
}
// cancel will interrupt the currently active backoff.
func (s *ServerStatistics) cancel() {
s.blacklisted.Store(false)
s.backoffUntil.Store(time.Time{})
select {
case s.interrupt <- struct{}{}:
default:
}
}
// Success updates the server statistics with a new successful
// attempt, which increases the sent counter and resets the idle and
// failure counters. If a host was blacklisted at this point then
// we will unblacklist it.
func (s *ServerStatistics) Success() {
s.cancel()
s.successCounter.Inc()
s.backoffCount.Store(0)
if s.statistics.DB != nil {
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
}
}
}
// Failure marks a failure and starts backing off if needed.
// The next call to BackoffIfRequired will do the right thing
// after this. It will return the time that the current failure
// will result in backoff waiting until, and a bool signalling
// whether we have blacklisted and therefore to give up.
func (s *ServerStatistics) Failure() (time.Time, bool) {
// If we aren't already backing off, this call will start
// a new backoff period. Increase the failure counter and
// start a goroutine which will wait out the backoff and
// unset the backoffStarted flag when done.
if s.backoffStarted.CAS(false, true) {
if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
s.blacklisted.Store(true)
if s.statistics.DB != nil {
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
}
}
return time.Time{}, true
}
go func() {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
select {
case <-time.After(time.Until(until)):
case <-s.interrupt:
}
}
s.backoffStarted.Store(false)
}()
}
// Check if we have blacklisted this node.
if s.blacklisted.Load() {
return time.Now(), true
}
// If we're already backing off and we haven't yet surpassed
// the deadline then return that. Repeated calls to Failure
// within a single backoff interval will have no side effects.
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
return until, false
}
// We're either backing off and have passed the deadline, or
// we aren't backing off, so work out what the next interval
// will be.
count := s.backoffCount.Load()
until := time.Now().Add(s.duration(count))
s.backoffUntil.Store(until)
return until, false
}
// BackoffInfo returns information about the current or previous backoff.
// Returns the last backoffUntil time and whether the server is currently blacklisted or not.
func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
until, ok := s.backoffUntil.Load().(time.Time)
if ok {
return &until, s.blacklisted.Load()
}
return nil, s.blacklisted.Load()
}
// Blacklisted returns true if the server is blacklisted and false
// otherwise.
func (s *ServerStatistics) Blacklisted() bool {
return s.blacklisted.Load()
}
// SuccessCount returns the number of successful requests. This is
// usually useful in constructing transaction IDs.
func (s *ServerStatistics) SuccessCount() uint32 {
return s.successCounter.Load()
}

View file

@ -0,0 +1,64 @@
package statistics
import (
"math"
"testing"
"time"
)
func TestBackoff(t *testing.T) {
stats := Statistics{
FailuresUntilBlacklist: 7,
}
server := ServerStatistics{
statistics: &stats,
serverName: "test.com",
}
// Start by checking that counting successes works.
server.Success()
if successes := server.SuccessCount(); successes != 1 {
t.Fatalf("Expected success count 1, got %d", successes)
}
// Register a failure.
server.Failure()
t.Logf("Backoff counter: %d", server.backoffCount.Load())
// Now we're going to simulate backing off a few times to see
// what happens.
for i := uint32(1); i <= 10; i++ {
// Register another failure for good measure. This should have no
// side effects since a backoff is already in progress. If it does
// then we'll fail.
until, blacklisted := server.Failure()
// Get the duration.
_, blacklist := server.BackoffInfo()
duration := time.Until(until).Round(time.Second)
// Unset the backoff, or otherwise our next call will think that
// there's a backoff in progress and return the same result.
server.cancel()
server.backoffStarted.Store(false)
// Check if we should be blacklisted by now.
if i >= stats.FailuresUntilBlacklist {
if !blacklist {
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
} else if blacklist != blacklisted {
t.Fatalf("BackoffInfo and Failure returned different blacklist values")
} else {
t.Logf("Backoff %d is blacklisted as expected", i)
continue
}
}
// Check if the duration is what we expect.
t.Logf("Backoff %d is for %s", i, duration)
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
}
}
}

68
federationapi/storage/cache/keydb.go vendored Normal file
View file

@ -0,0 +1,68 @@
package cache
import (
"context"
"errors"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/gomatrixserverlib"
)
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
// the public keys for other matrix servers.
type KeyDatabase struct {
inner gomatrixserverlib.KeyDatabase
cache caching.ServerKeyCache
}
func NewKeyDatabase(inner gomatrixserverlib.KeyDatabase, cache caching.ServerKeyCache) (*KeyDatabase, error) {
if inner == nil {
return nil, errors.New("inner database can't be nil")
}
if cache == nil {
return nil, errors.New("cache can't be nil")
}
return &KeyDatabase{
inner: inner,
cache: cache,
}, nil
}
// FetcherName implements KeyFetcher
func (d KeyDatabase) FetcherName() string {
return "InMemoryKeyCache"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *KeyDatabase) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
for req, ts := range requests {
if res, cached := d.cache.GetServerKey(req, ts); cached {
results[req] = res
delete(requests, req)
}
}
fromDB, err := d.inner.FetchKeys(ctx, requests)
if err != nil {
return results, err
}
for req, res := range fromDB {
results[req] = res
d.cache.StoreServerKey(req, res)
}
return results, nil
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *KeyDatabase) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
for req, res := range keyMap {
d.cache.StoreServerKey(req, res)
}
return d.inner.StoreKeys(ctx, keyMap)
}

View file

@ -0,0 +1,76 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"context"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
)
type Database interface {
internal.PartitionStorer
gomatrixserverlib.KeyDatabase
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
PurgeRoomState(ctx context.Context, roomID string) error
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
// these don't have contexts passed in as we want things to happen regardless of the request context
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error
RemoveAllServersFromBlacklist() error
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error)
AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error)
GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error)
// Update the notary with the given server keys from the given server name.
UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error
// Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs))
// such that the combination of all server keys will include all the `optKeyIDs`.
GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
}

View file

@ -0,0 +1,115 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const blacklistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
-- The blacklisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
);
`
const insertBlacklistSQL = "" +
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectBlacklistSQL = "" +
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
const deleteBlacklistSQL = "" +
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
const deleteAllBlacklistSQL = "" +
"TRUNCATE federationsender_blacklist"
type blacklistStatements struct {
db *sql.DB
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
deleteAllBlacklistStmt *sql.Stmt
}
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{
db: db,
}
_, err = db.Exec(blacklistSchema)
if err != nil {
return
}
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
return
}
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
return
}
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
return
}
if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil {
return
}
return
}
func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *blacklistStatements) SelectBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is blacklisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *blacklistStatements) DeleteAllBlacklist(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -0,0 +1,46 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveRoomsTable(tx *sql.Tx) error {
// We can't reverse this.
return nil
}

View file

@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const inboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts BIGINT NOT NULL,
renewed_ts BIGINT NOT NULL,
renewal_interval BIGINT NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
const insertInboundPeekSQL = "" +
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
type inboundPeeksStatements struct {
db *sql.DB
insertInboundPeekStmt *sql.Stmt
selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt *sql.Stmt
renewInboundPeekStmt *sql.Stmt
deleteInboundPeekStmt *sql.Stmt
deleteInboundPeeksStmt *sql.Stmt
}
func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
s = &inboundPeeksStatements{
db: db,
}
_, err = db.Exec(inboundPeeksSchema)
if err != nil {
return
}
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
}
func (s *inboundPeeksStatements) InsertInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
return
}
func (s *inboundPeeksStatements) RenewInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
return
}
func (s *inboundPeeksStatements) SelectInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.InboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
inboundPeek := types.InboundPeek{}
err := row.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &inboundPeek, nil
}
func (s *inboundPeeksStatements) SelectInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (inboundPeeks []types.InboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
for rows.Next() {
inboundPeek := types.InboundPeek{}
if err = rows.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
); err != nil {
return
}
inboundPeeks = append(inboundPeeks, inboundPeek)
}
return inboundPeeks, rows.Err()
}
func (s *inboundPeeksStatements) DeleteInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
return
}
func (s *inboundPeeksStatements) DeleteInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
return
}

View file

@ -0,0 +1,212 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event.
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id)
`
const insertJoinedHostsSQL = "" +
"INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
selectJoinedHostsForRoomsStmt *sql.Stmt
}
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
db: db,
}
_, err = s.db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = s.db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
return
}
return
}
func (s *joinedHostsStatements) InsertJoinedHosts(
ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
return err
}
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
return joinedHostsFromStmt(ctx, stmt, roomID)
}
func (s *joinedHostsStatements) SelectJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
}
func (s *joinedHostsStatements) SelectAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, rows.Err()
}

View file

@ -0,0 +1,64 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const notaryServerKeysJSONSchema = `
CREATE SEQUENCE IF NOT EXISTS federationsender_notary_server_keys_json_pkey;
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_json (
notary_id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('federationsender_notary_server_keys_json_pkey'),
response_json TEXT NOT NULL,
server_name TEXT NOT NULL,
valid_until BIGINT NOT NULL
);
`
const insertServerKeysJSONSQL = "" +
"INSERT INTO federationsender_notary_server_keys_json (response_json, server_name, valid_until) VALUES ($1, $2, $3)" +
" RETURNING notary_id"
type notaryServerKeysStatements struct {
db *sql.DB
insertServerKeysJSONStmt *sql.Stmt
}
func NewPostgresNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, err error) {
s = &notaryServerKeysStatements{
db: db,
}
_, err = db.Exec(notaryServerKeysJSONSchema)
if err != nil {
return
}
if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil {
return
}
return
}
func (s *notaryServerKeysStatements) InsertJSONResponse(
ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp,
) (tables.NotaryID, error) {
var notaryID tables.NotaryID
return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(&notaryID)
}

View file

@ -0,0 +1,167 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"encoding/json"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/gomatrixserverlib"
)
const notaryServerKeysMetadataSchema = `
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_metadata (
notary_id BIGINT NOT NULL,
server_name TEXT NOT NULL,
key_id TEXT NOT NULL,
UNIQUE (server_name, key_id)
);
`
const upsertServerKeysSQL = "" +
"INSERT INTO federationsender_notary_server_keys_metadata (notary_id, server_name, key_id) VALUES ($1, $2, $3)" +
" ON CONFLICT (server_name, key_id) DO UPDATE SET notary_id = $1"
// for a given (server_name, key_id), find the existing notary ID and valid until. Used to check if we will replace it
// JOINs with the json table
const selectNotaryKeyMetadataSQL = `
SELECT federationsender_notary_server_keys_metadata.notary_id, valid_until FROM federationsender_notary_server_keys_json
JOIN federationsender_notary_server_keys_metadata ON
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
WHERE federationsender_notary_server_keys_metadata.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = $2
`
// select the response which has the highest valid_until value
// JOINs with the json table
const selectNotaryKeyResponsesSQL = `
SELECT response_json FROM federationsender_notary_server_keys_json
WHERE server_name = $1 AND valid_until = (
SELECT MAX(valid_until) FROM federationsender_notary_server_keys_json WHERE server_name = $1
)
`
// select the responses which have the given key IDs
// JOINs with the json table
const selectNotaryKeyResponsesWithKeyIDsSQL = `
SELECT response_json FROM federationsender_notary_server_keys_json
JOIN federationsender_notary_server_keys_metadata ON
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
WHERE federationsender_notary_server_keys_json.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = ANY ($2)
GROUP BY federationsender_notary_server_keys_json.notary_id
`
// JOINs with the metadata table
const deleteUnusedServerKeysJSONSQL = `
DELETE FROM federationsender_notary_server_keys_json WHERE federationsender_notary_server_keys_json.notary_id NOT IN (
SELECT DISTINCT notary_id FROM federationsender_notary_server_keys_metadata
)
`
type notaryServerKeysMetadataStatements struct {
db *sql.DB
upsertServerKeysStmt *sql.Stmt
selectNotaryKeyResponsesStmt *sql.Stmt
selectNotaryKeyResponsesWithKeyIDsStmt *sql.Stmt
selectNotaryKeyMetadataStmt *sql.Stmt
deleteUnusedServerKeysJSONStmt *sql.Stmt
}
func NewPostgresNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMetadataStatements, err error) {
s = &notaryServerKeysMetadataStatements{
db: db,
}
_, err = db.Exec(notaryServerKeysMetadataSchema)
if err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil {
return
}
if s.selectNotaryKeyResponsesWithKeyIDsStmt, err = db.Prepare(selectNotaryKeyResponsesWithKeyIDsSQL); err != nil {
return
}
if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil {
return
}
if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil {
return
}
return
}
func (s *notaryServerKeysMetadataStatements) UpsertKey(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp,
) (tables.NotaryID, error) {
notaryID := newNotaryID
// see if the existing notary ID a) exists, b) has a longer valid_until
var existingNotaryID tables.NotaryID
var existingValidUntil gomatrixserverlib.Timestamp
if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil {
if err != sql.ErrNoRows {
return 0, err
}
}
if existingValidUntil.Time().After(newValidUntil.Time()) {
// the existing valid_until is valid longer, so use that.
return existingNotaryID, nil
}
// overwrite the notary_id for this (server_name, key_id) tuple
_, err := txn.Stmt(s.upsertServerKeysStmt).ExecContext(ctx, notaryID, serverName, keyID)
return notaryID, err
}
func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) {
var rows *sql.Rows
var err error
if len(keyIDs) == 0 {
rows, err = txn.Stmt(s.selectNotaryKeyResponsesStmt).QueryContext(ctx, string(serverName))
} else {
keyIDstr := make([]string, len(keyIDs))
for i := range keyIDs {
keyIDstr[i] = string(keyIDs[i])
}
rows, err = txn.Stmt(s.selectNotaryKeyResponsesWithKeyIDsStmt).QueryContext(ctx, string(serverName), pq.StringArray(keyIDstr))
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectNotaryKeyResponsesStmt close failed")
var results []gomatrixserverlib.ServerKeys
for rows.Next() {
var sk gomatrixserverlib.ServerKeys
var raw string
if err = rows.Scan(&raw); err != nil {
return nil, err
}
if err = json.Unmarshal([]byte(raw), &sk); err != nil {
return nil, err
}
results = append(results, sk)
}
return results, nil
}
func (s *notaryServerKeysMetadataStatements) DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error {
_, err := txn.Stmt(s.deleteUnusedServerKeysJSONStmt).ExecContext(ctx)
return err
}

View file

@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const outboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts BIGINT NOT NULL,
renewed_ts BIGINT NOT NULL,
renewal_interval BIGINT NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
const insertOutboundPeekSQL = "" +
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
type outboundPeeksStatements struct {
db *sql.DB
insertOutboundPeekStmt *sql.Stmt
selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt *sql.Stmt
renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt *sql.Stmt
deleteOutboundPeeksStmt *sql.Stmt
}
func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
s = &outboundPeeksStatements{
db: db,
}
_, err = db.Exec(outboundPeeksSchema)
if err != nil {
return
}
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
}
func (s *outboundPeeksStatements) InsertOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
return
}
func (s *outboundPeeksStatements) RenewOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
return
}
func (s *outboundPeeksStatements) SelectOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.OutboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
outboundPeek := types.OutboundPeek{}
err := row.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &outboundPeek, nil
}
func (s *outboundPeeksStatements) SelectOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (outboundPeeks []types.OutboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")
for rows.Next() {
outboundPeek := types.OutboundPeek{}
if err = rows.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
); err != nil {
return
}
outboundPeeks = append(outboundPeeks, outboundPeek)
}
return outboundPeeks, rows.Err()
}
func (s *outboundPeeksStatements) DeleteOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
return
}
func (s *outboundPeeksStatements) DeleteOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
return
}

View file

@ -0,0 +1,198 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The type of the event (informational).
edu_type TEXT NOT NULL,
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name);
`
const insertQueueEDUSQL = "" +
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueEDUSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueueEDUSQL = "" +
"SELECT json_nid FROM federationsender_queue_edus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
deleteQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
}
func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
db: db,
}
_, err = s.db.Exec(queueEDUsSchema)
if err != nil {
return
}
if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil {
return
}
if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
return
}
func (s *queueEDUsStatements) InsertQueueEDU(
ctx context.Context,
txn *sql.Tx,
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
ctx,
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queueEDUsStatements) DeleteQueueEDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueueEDUStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *queueEDUsStatements) SelectQueueEDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, nil
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
return -1, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, serverName)
}
return result, rows.Err()
}

View file

@ -0,0 +1,115 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const queueJSONSchema = `
-- The federationsender_queue_json table contains event contents that
-- we failed to send.
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON NID. This allows the federationsender_queue_retry table to
-- cross-reference to find the JSON blob.
json_nid BIGSERIAL,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
`
const insertJSONSQL = "" +
"INSERT INTO federationsender_queue_json (json_body)" +
" VALUES ($1)" +
" RETURNING json_nid"
const deleteJSONSQL = "" +
"DELETE FROM federationsender_queue_json WHERE json_nid = ANY($1)"
const selectJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_queue_json" +
" WHERE json_nid = ANY($1)"
type queueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
deleteJSONStmt *sql.Stmt
selectJSONStmt *sql.Stmt
}
func NewPostgresQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
db: db,
}
_, err = s.db.Exec(queueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = s.db.Prepare(insertJSONSQL); err != nil {
return
}
if s.deleteJSONStmt, err = s.db.Prepare(deleteJSONSQL); err != nil {
return
}
if s.selectJSONStmt, err = s.db.Prepare(selectJSONSQL); err != nil {
return
}
return
}
func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (int64, error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
var lastid int64
if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil {
return 0, err
}
return lastid, nil
}
func (s *queueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
_, err := stmt.ExecContext(ctx, pq.Int64Array(nids))
return err
}
func (s *queueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, s.selectJSONStmt)
rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, err
}
blobs[nid] = blob
}
return blobs, err
}

View file

@ -0,0 +1,202 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queuePDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The destination server that we will send the event to.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_pdus_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name);
`
const insertQueuePDUSQL = "" +
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueuePDUSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)"
const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueuePDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
type queuePDUsStatements struct {
db *sql.DB
insertQueuePDUStmt *sql.Stmt
deleteQueuePDUsStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueuePDUServerNamesStmt *sql.Stmt
}
func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
db: db,
}
_, err = s.db.Exec(queuePDUsSchema)
if err != nil {
return
}
if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil {
return
}
if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil {
return
}
if s.selectQueuePDUsStmt, err = s.db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
return
}
return
}
func (s *queuePDUsStatements) InsertQueuePDU(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queuePDUsStatements) DeleteQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt)
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
return err
}
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queuePDUsStatements) SelectQueuePDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, serverName)
}
return result, rows.Err()
}

View file

@ -0,0 +1,146 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const serverSigningKeysSchema = `
-- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for.
server_name TEXT NOT NULL,
-- The ID of the server key.
server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL,
-- The base64-encoded public key.
server_key TEXT NOT NULL,
CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
);
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
`
const bulkSelectServerSigningKeysSQL = "" +
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
" server_key FROM keydb_server_keys" +
" WHERE server_name_and_key_id = ANY($1)"
const upsertServerSigningKeysSQL = "" +
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverSigningKeyStatements struct {
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func NewPostgresServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, err error) {
s = &serverSigningKeyStatements{}
_, err = db.Exec(serverSigningKeysSchema)
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil {
return
}
return s, nil
}
func (s *serverSigningKeyStatements) BulkSelectServerKeys(
ctx context.Context, txn *sql.Tx,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
stmt := s.bulkSelectServerKeysStmt
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key)
if err != nil {
return nil, err
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return results, rows.Err()
}
func (s *serverSigningKeyStatements) UpsertServerKeys(
ctx context.Context, txn *sql.Tx,
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
_, err := stmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
key.ValidUntilTS,
key.ExpiredTS,
key.Key.Encode(),
)
return err
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
return string(request.ServerName) + "\x1F" + string(request.KeyID)
}

View file

@ -0,0 +1,109 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
d.writer = sqlutil.NewDummyWriter()
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
if err != nil {
return nil, err
}
queuePDUs, err := NewPostgresQueuePDUsTable(d.db)
if err != nil {
return nil, err
}
queueEDUs, err := NewPostgresQueueEDUsTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewPostgresQueueJSONTable(d.db)
if err != nil {
return nil, err
}
blacklist, err := NewPostgresBlacklistTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewPostgresOutboundPeeksTable(d.db)
if err != nil {
return nil, err
}
notaryJSON, err := NewPostgresNotaryServerKeysTable(d.db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotaryServerKeysTable: %s", err)
}
notaryMetadata, err := NewPostgresNotaryServerKeysMetadataTable(d.db)
if err != nil {
return nil, fmt.Errorf("NewPostgresNotaryServerKeysMetadataTable: %s", err)
}
serverSigningKeys, err := NewPostgresServerSigningKeysTable(d.db)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationInboundPeeks: inboundPeeks,
FederationOutboundPeeks: outboundPeeks,
NotaryServerKeysJSON: notaryJSON,
NotaryServerKeysMetadata: notaryMetadata,
ServerSigningKeys: serverSigningKeys,
}
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err
}
return &d, nil
}

View file

@ -0,0 +1,247 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
type Database struct {
DB *sql.DB
Cache caching.FederationCache
Writer sqlutil.Writer
FederationQueuePDUs tables.FederationQueuePDUs
FederationQueueEDUs tables.FederationQueueEDUs
FederationQueueJSON tables.FederationQueueJSON
FederationJoinedHosts tables.FederationJoinedHosts
FederationBlacklist tables.FederationBlacklist
FederationOutboundPeeks tables.FederationOutboundPeeks
FederationInboundPeeks tables.FederationInboundPeeks
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
ServerSigningKeys tables.FederationServerSigningKeys
}
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
// We don't actually export the NIDs but we need the caller to be able
// to pass them back so that we can clean up if the transaction sends
// successfully.
type Receipt struct {
nid int64
}
func (r *Receipt) String() string {
return fmt.Sprintf("%d", r.nid)
}
// UpdateRoom updates the joined hosts for a room and returns what the joined
// hosts were before the update, or nil if this was a duplicate message.
// This is called when we receive a message from kafka, so we pass in
// oldEventID and newEventID to check that we haven't missed any messages or
// this isn't a duplicate message.
func (d *Database) UpdateRoom(
ctx context.Context,
roomID, oldEventID, newEventID string,
addHosts []types.JoinedHost,
removeHosts []string,
) (joinedHosts []types.JoinedHost, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
if err != nil {
return err
}
for _, add := range addHosts {
err = d.FederationJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
if err != nil {
return err
}
}
if err = d.FederationJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil {
return err
}
return nil
})
return
}
// GetJoinedHosts returns the currently joined hosts for room,
// as known to federationserver.
// Returns an error if something goes wrong.
func (d *Database) GetJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return d.FederationJoinedHosts.SelectJoinedHosts(ctx, roomID)
}
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
}
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
return d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
}
// StoreJSON adds a JSON blob into the queue JSON table and returns
// a NID. The NID will then be used when inserting the per-destination
// metadata entries.
func (d *Database) StoreJSON(
ctx context.Context, js string,
) (*Receipt, error) {
var nid int64
var err error
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nid, err = d.FederationQueueJSON.InsertQueueJSON(ctx, txn, js)
return err
})
if err != nil {
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
}
return &Receipt{
nid: nid,
}, nil
}
func (d *Database) PurgeRoomState(
ctx context.Context, roomID string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// If the event is a create event then we'll delete all of the existing
// data for the room. The only reason that a create event would be replayed
// to us in this way is if we're about to receive the entire room state.
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
}
return nil
})
}
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
})
}
func (d *Database) RemoveAllServersFromBlacklist() error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
})
}
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
}
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
}
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
})
}
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
}
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
}
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
validUntil := serverKeys.ValidUntilTS
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
// This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of
// time without a way for the homeserver owner to revoke it.
// https://spec.matrix.org/unstable/server-server-api/#querying-keys-through-another-server
weekIntoFuture := time.Now().Add(7 * 24 * time.Hour)
if weekIntoFuture.Before(validUntil.Time()) {
validUntil = gomatrixserverlib.AsTimestamp(weekIntoFuture)
}
notaryID, err := d.NotaryServerKeysJSON.InsertJSONResponse(ctx, txn, serverKeys, serverName, validUntil)
if err != nil {
return err
}
// update the metadata for the keys
for keyID := range serverKeys.OldVerifyKeys {
_, err = d.NotaryServerKeysMetadata.UpsertKey(ctx, txn, serverName, keyID, notaryID, validUntil)
if err != nil {
return err
}
}
for keyID := range serverKeys.VerifyKeys {
_, err = d.NotaryServerKeysMetadata.UpsertKey(ctx, txn, serverName, keyID, notaryID, validUntil)
if err != nil {
return err
}
}
// clean up old responses
return d.NotaryServerKeysMetadata.DeleteOldJSONResponses(ctx, txn)
})
}
func (d *Database) GetNotaryKeys(
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
) (sks []gomatrixserverlib.ServerKeys, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)
return err
})
return sks, err
}

View file

@ -0,0 +1,151 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
// AssociateEDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociateEDUWithDestination(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipt *Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.InsertQueueEDU(
ctx, // context
txn, // SQL transaction
"", // TODO: EDU type for coalescing
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
); err != nil {
return fmt.Errorf("InsertQueueEDU: %w", err)
}
return nil
})
}
// GetNextTransactionEDUs retrieves events from the database for
// the next pending transaction, up to the limit specified.
func (d *Database) GetPendingEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
limit int,
) (
edus map[*Receipt]*gomatrixserverlib.EDU,
err error,
) {
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
if err != nil {
return fmt.Errorf("SelectQueueEDUs: %w", err)
}
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok {
edus[&Receipt{nid}] = edu
} else {
retrieve = append(retrieve, nid)
}
}
blobs, err := d.FederationQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
if err != nil {
return fmt.Errorf("SelectQueueJSON: %w", err)
}
for nid, blob := range blobs {
var event gomatrixserverlib.EDU
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
edus[&Receipt{nid}] = &event
}
return nil
})
return
}
// CleanEDUs cleans up all specified EDUs. This is done when a
// transaction was sent successfully.
func (d *Database) CleanEDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
}
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, nids); err != nil {
return err
}
var deleteNIDs []int64
for _, nid := range nids {
count, err := d.FederationQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid)
if err != nil {
return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err)
}
if count == 0 {
deleteNIDs = append(deleteNIDs, nid)
d.Cache.EvictFederationQueuedEDU(nid)
}
}
if len(deleteNIDs) > 0 {
if err := d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
return fmt.Errorf("DeleteQueueJSON: %w", err)
}
}
return nil
})
}
// GetPendingEDUCount returns the number of EDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingEDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have EDUs
// waiting to be sent.
func (d *Database) GetPendingEDUServerNames(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
}

View file

@ -0,0 +1,59 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"github.com/matrix-org/gomatrixserverlib"
)
// FetcherName implements KeyFetcher
func (d Database) FetcherName() string {
return "FederationAPIKeyDatabase"
}
// FetchKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) FetchKeys(
ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
return d.ServerSigningKeys.BulkSelectServerKeys(ctx, nil, requests)
}
// StoreKeys implements gomatrixserverlib.KeyDatabase
func (d *Database) StoreKeys(
ctx context.Context,
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var lastErr error
for request, keys := range keyMap {
if err := d.ServerSigningKeys.UpsertServerKeys(ctx, txn, request, keys); err != nil {
// Rather than returning immediately on error we try to insert the
// remaining keys.
// Since we are inserting the keys outside of a transaction it is
// possible for some of the inserts to succeed even though some
// of the inserts have failed.
// Ensuring that we always insert all the keys we can means that
// this behaviour won't depend on the iteration order of the map.
lastErr = err
}
}
return lastErr
})
}

View file

@ -0,0 +1,159 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"github.com/matrix-org/gomatrixserverlib"
)
// AssociatePDUWithDestination creates an association that the
// destination queues will use to determine which JSON blobs to send
// to which servers.
func (d *Database) AssociatePDUWithDestination(
ctx context.Context,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
receipt *Receipt,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueuePDUs.InsertQueuePDU(
ctx, // context
txn, // SQL transaction
transactionID, // transaction ID
serverName, // destination server name
receipt.nid, // NID from the federationapi_queue_json table
); err != nil {
return fmt.Errorf("InsertQueuePDU: %w", err)
}
return nil
})
}
// GetNextTransactionPDUs retrieves events from the database for
// the next pending transaction, up to the limit specified.
func (d *Database) GetPendingPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
limit int,
) (
events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
err error,
) {
// Strictly speaking this doesn't need to be using the writer
// since we are only performing selects, but since we don't have
// a guarantee of transactional isolation, it's actually useful
// to know in SQLite mode that nothing else is trying to modify
// the database.
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
if err != nil {
return fmt.Errorf("SelectQueuePDUs: %w", err)
}
retrieve := make([]int64, 0, len(nids))
for _, nid := range nids {
if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok {
events[&Receipt{nid}] = event
} else {
retrieve = append(retrieve, nid)
}
}
blobs, err := d.FederationQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
if err != nil {
return fmt.Errorf("SelectQueueJSON: %w", err)
}
for nid, blob := range blobs {
var event gomatrixserverlib.HeaderedEvent
if err := json.Unmarshal(blob, &event); err != nil {
return fmt.Errorf("json.Unmarshal: %w", err)
}
events[&Receipt{nid}] = &event
d.Cache.StoreFederationQueuedPDU(nid, &event)
}
return nil
})
return
}
// CleanTransactionPDUs cleans up all associated events for a
// given transaction. This is done when the transaction was sent
// successfully.
func (d *Database) CleanPDUs(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
receipts []*Receipt,
) error {
if len(receipts) == 0 {
return errors.New("expected receipt")
}
nids := make([]int64, len(receipts))
for i := range receipts {
nids[i] = receipts[i].nid
}
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.FederationQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, nids); err != nil {
return err
}
var deleteNIDs []int64
for _, nid := range nids {
count, err := d.FederationQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid)
if err != nil {
return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err)
}
if count == 0 {
deleteNIDs = append(deleteNIDs, nid)
d.Cache.EvictFederationQueuedPDU(nid)
}
}
if len(deleteNIDs) > 0 {
if err := d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
return fmt.Errorf("DeleteQueueJSON: %w", err)
}
}
return nil
})
}
// GetPendingPDUCount returns the number of PDUs waiting to be
// sent for a given servername.
func (d *Database) GetPendingPDUCount(
ctx context.Context,
serverName gomatrixserverlib.ServerName,
) (int64, error) {
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
}
// GetPendingServerNames returns the server names that have PDUs
// waiting to be sent.
func (d *Database) GetPendingPDUServerNames(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
return d.FederationQueuePDUs.SelectQueuePDUServerNames(ctx, nil)
}

View file

@ -0,0 +1,115 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const blacklistSchema = `
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
-- The blacklisted server name
server_name TEXT NOT NULL,
UNIQUE (server_name)
);
`
const insertBlacklistSQL = "" +
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
" ON CONFLICT DO NOTHING"
const selectBlacklistSQL = "" +
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
const deleteBlacklistSQL = "" +
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
const deleteAllBlacklistSQL = "" +
"DELETE FROM federationsender_blacklist"
type blacklistStatements struct {
db *sql.DB
insertBlacklistStmt *sql.Stmt
selectBlacklistStmt *sql.Stmt
deleteBlacklistStmt *sql.Stmt
deleteAllBlacklistStmt *sql.Stmt
}
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
s = &blacklistStatements{
db: db,
}
_, err = db.Exec(blacklistSchema)
if err != nil {
return
}
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
return
}
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
return
}
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
return
}
if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil {
return
}
return
}
func (s *blacklistStatements) InsertBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *blacklistStatements) SelectBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (bool, error) {
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
res, err := stmt.QueryContext(ctx, serverName)
if err != nil {
return false, err
}
defer res.Close() // nolint:errcheck
// The query will return the server name if the server is blacklisted, and
// will return no rows if not. By calling Next, we find out if a row was
// returned or not - we don't care about the value itself.
return res.Next(), nil
}
func (s *blacklistStatements) DeleteBlacklist(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
_, err := stmt.ExecContext(ctx, serverName)
return err
}
func (s *blacklistStatements) DeleteAllBlacklist(
ctx context.Context, txn *sql.Tx,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
_, err := stmt.ExecContext(ctx)
return err
}

View file

@ -0,0 +1,46 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package deltas
import (
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/pressly/goose"
)
func LoadFromGoose() {
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
}
func UpRemoveRoomsTable(tx *sql.Tx) error {
_, err := tx.Exec(`
DROP TABLE IF EXISTS federationsender_rooms;
`)
if err != nil {
return fmt.Errorf("failed to execute upgrade: %w", err)
}
return nil
}
func DownRemoveRoomsTable(tx *sql.Tx) error {
// We can't reverse this.
return nil
}

View file

@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const inboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
const insertInboundPeekSQL = "" +
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectInboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectInboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
const renewInboundPeekSQL = "" +
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteInboundPeekSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteInboundPeeksSQL = "" +
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
type inboundPeeksStatements struct {
db *sql.DB
insertInboundPeekStmt *sql.Stmt
selectInboundPeekStmt *sql.Stmt
selectInboundPeeksStmt *sql.Stmt
renewInboundPeekStmt *sql.Stmt
deleteInboundPeekStmt *sql.Stmt
deleteInboundPeeksStmt *sql.Stmt
}
func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
s = &inboundPeeksStatements{
db: db,
}
_, err = db.Exec(inboundPeeksSchema)
if err != nil {
return
}
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
return
}
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
return
}
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
return
}
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
return
}
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
return
}
return
}
func (s *inboundPeeksStatements) InsertInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
return
}
func (s *inboundPeeksStatements) RenewInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
return
}
func (s *inboundPeeksStatements) SelectInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.InboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
inboundPeek := types.InboundPeek{}
err := row.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &inboundPeek, nil
}
func (s *inboundPeeksStatements) SelectInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (inboundPeeks []types.InboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
for rows.Next() {
inboundPeek := types.InboundPeek{}
if err = rows.Scan(
&inboundPeek.RoomID,
&inboundPeek.ServerName,
&inboundPeek.PeekID,
&inboundPeek.CreationTimestamp,
&inboundPeek.RenewedTimestamp,
&inboundPeek.RenewalInterval,
); err != nil {
return
}
inboundPeeks = append(inboundPeeks, inboundPeek)
}
return inboundPeeks, rows.Err()
}
func (s *inboundPeeksStatements) DeleteInboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
return
}
func (s *inboundPeeksStatements) DeleteInboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
return
}

View file

@ -0,0 +1,219 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const joinedHostsSchema = `
-- The joined_hosts table stores a list of m.room.member event ids in the
-- current state for each room where the membership is "join".
-- There will be an entry for every user that is joined to the room.
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
-- The string ID of the room.
room_id TEXT NOT NULL,
-- The event ID of the m.room.member join event.
event_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
ON federationsender_joined_hosts (event_id);
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
ON federationsender_joined_hosts (room_id)
`
const insertJoinedHostsSQL = "" +
"INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
" VALUES ($1, $2, $3)"
const deleteJoinedHostsSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
const deleteJoinedHostsForRoomSQL = "" +
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
const selectJoinedHostsForRoomsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
type joinedHostsStatements struct {
db *sql.DB
insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt
deleteJoinedHostsForRoomStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
s = &joinedHostsStatements{
db: db,
}
_, err = db.Exec(joinedHostsSchema)
if err != nil {
return
}
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
return
}
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
return
}
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return
}
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
return
}
func (s *joinedHostsStatements) InsertJoinedHosts(
ctx context.Context,
txn *sql.Tx,
roomID, eventID string,
serverName gomatrixserverlib.ServerName,
) error {
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
return err
}
func (s *joinedHostsStatements) DeleteJoinedHosts(
ctx context.Context, txn *sql.Tx, eventIDs []string,
) error {
for _, eventID := range eventIDs {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
return err
}
}
return nil
}
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
ctx context.Context, txn *sql.Tx, roomID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
_, err := stmt.ExecContext(ctx, roomID)
return err
}
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
ctx context.Context, txn *sql.Tx, roomID string,
) ([]types.JoinedHost, error) {
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
return joinedHostsFromStmt(ctx, stmt, roomID)
}
func (s *joinedHostsStatements) SelectJoinedHosts(
ctx context.Context, roomID string,
) ([]types.JoinedHost, error) {
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
}
func (s *joinedHostsStatements) SelectAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
ctx context.Context, roomIDs []string,
) ([]gomatrixserverlib.ServerName, error) {
iRoomIDs := make([]interface{}, len(roomIDs))
for i := range roomIDs {
iRoomIDs[i] = roomIDs[i]
}
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) {
rows, err := stmt.QueryContext(ctx, roomID)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
var result []types.JoinedHost
for rows.Next() {
var eventID, serverName string
if err = rows.Scan(&eventID, &serverName); err != nil {
return nil, err
}
result = append(result, types.JoinedHost{
MemberEventID: eventID,
ServerName: gomatrixserverlib.ServerName(serverName),
})
}
return result, nil
}

View file

@ -0,0 +1,63 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
)
const notaryServerKeysJSONSchema = `
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_json (
notary_id INTEGER PRIMARY KEY AUTOINCREMENT,
response_json TEXT NOT NULL,
server_name TEXT NOT NULL,
valid_until BIGINT NOT NULL
);
`
const insertServerKeysJSONSQL = "" +
"INSERT INTO federationsender_notary_server_keys_json (response_json, server_name, valid_until) VALUES ($1, $2, $3)" +
" RETURNING notary_id"
type notaryServerKeysStatements struct {
db *sql.DB
insertServerKeysJSONStmt *sql.Stmt
}
func NewSQLiteNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, err error) {
s = &notaryServerKeysStatements{
db: db,
}
_, err = db.Exec(notaryServerKeysJSONSchema)
if err != nil {
return
}
if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil {
return
}
return
}
func (s *notaryServerKeysStatements) InsertJSONResponse(
ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp,
) (tables.NotaryID, error) {
var notaryID tables.NotaryID
return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(&notaryID)
}

View file

@ -0,0 +1,169 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/dendrite/federationapi/storage/tables"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const notaryServerKeysMetadataSchema = `
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_metadata (
notary_id BIGINT NOT NULL,
server_name TEXT NOT NULL,
key_id TEXT NOT NULL,
UNIQUE (server_name, key_id)
);
`
const upsertServerKeysSQL = "" +
"INSERT INTO federationsender_notary_server_keys_metadata (notary_id, server_name, key_id) VALUES ($1, $2, $3)" +
" ON CONFLICT (server_name, key_id) DO UPDATE SET notary_id = $1"
// for a given (server_name, key_id), find the existing notary ID and valid until. Used to check if we will replace it
// JOINs with the json table
const selectNotaryKeyMetadataSQL = `
SELECT federationsender_notary_server_keys_metadata.notary_id, valid_until FROM federationsender_notary_server_keys_json
JOIN federationsender_notary_server_keys_metadata ON
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
WHERE federationsender_notary_server_keys_metadata.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = $2
`
// select the response which has the highest valid_until value
// JOINs with the json table
const selectNotaryKeyResponsesSQL = `
SELECT response_json FROM federationsender_notary_server_keys_json
WHERE server_name = $1 AND valid_until = (
SELECT MAX(valid_until) FROM federationsender_notary_server_keys_json WHERE server_name = $1
)
`
// select the responses which have the given key IDs
// JOINs with the json table
const selectNotaryKeyResponsesWithKeyIDsSQL = `
SELECT response_json FROM federationsender_notary_server_keys_json
JOIN federationsender_notary_server_keys_metadata ON
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
WHERE federationsender_notary_server_keys_json.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id IN ($2)
GROUP BY federationsender_notary_server_keys_json.notary_id
`
// JOINs with the metadata table
const deleteUnusedServerKeysJSONSQL = `
DELETE FROM federationsender_notary_server_keys_json WHERE federationsender_notary_server_keys_json.notary_id NOT IN (
SELECT DISTINCT notary_id FROM federationsender_notary_server_keys_metadata
)
`
type notaryServerKeysMetadataStatements struct {
db *sql.DB
upsertServerKeysStmt *sql.Stmt
selectNotaryKeyResponsesStmt *sql.Stmt
selectNotaryKeyMetadataStmt *sql.Stmt
deleteUnusedServerKeysJSONStmt *sql.Stmt
}
func NewSQLiteNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMetadataStatements, err error) {
s = &notaryServerKeysMetadataStatements{
db: db,
}
_, err = db.Exec(notaryServerKeysMetadataSchema)
if err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
return
}
if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil {
return
}
if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil {
return
}
if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil {
return
}
return
}
func (s *notaryServerKeysMetadataStatements) UpsertKey(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp,
) (tables.NotaryID, error) {
notaryID := newNotaryID
// see if the existing notary ID a) exists, b) has a longer valid_until
var existingNotaryID tables.NotaryID
var existingValidUntil gomatrixserverlib.Timestamp
if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil {
if err != sql.ErrNoRows {
return 0, err
}
}
if existingValidUntil.Time().After(newValidUntil.Time()) {
// the existing valid_until is valid longer, so use that.
return existingNotaryID, nil
}
// overwrite the notary_id for this (server_name, key_id) tuple
_, err := txn.Stmt(s.upsertServerKeysStmt).ExecContext(ctx, notaryID, serverName, keyID)
return notaryID, err
}
func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) {
var rows *sql.Rows
var err error
if len(keyIDs) == 0 {
rows, err = txn.Stmt(s.selectNotaryKeyResponsesStmt).QueryContext(ctx, string(serverName))
} else {
iKeyIDs := make([]interface{}, len(keyIDs)+1)
iKeyIDs[0] = serverName
for i := range keyIDs {
iKeyIDs[i+1] = string(keyIDs[i])
}
sql := strings.Replace(selectNotaryKeyResponsesWithKeyIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(keyIDs), 1), 1)
fmt.Println(sql)
fmt.Println(iKeyIDs...)
rows, err = s.db.QueryContext(ctx, sql, iKeyIDs...)
}
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectNotaryKeyResponsesStmt close failed")
var results []gomatrixserverlib.ServerKeys
for rows.Next() {
var sk gomatrixserverlib.ServerKeys
var raw string
if err = rows.Scan(&raw); err != nil {
return nil, err
}
if err = json.Unmarshal([]byte(raw), &sk); err != nil {
return nil, err
}
results = append(results, sk)
}
return results, nil
}
func (s *notaryServerKeysMetadataStatements) DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error {
_, err := txn.Stmt(s.deleteUnusedServerKeysJSONStmt).ExecContext(ctx)
return err
}

View file

@ -0,0 +1,176 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"time"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const outboundPeeksSchema = `
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
room_id TEXT NOT NULL,
server_name TEXT NOT NULL,
peek_id TEXT NOT NULL,
creation_ts INTEGER NOT NULL,
renewed_ts INTEGER NOT NULL,
renewal_interval INTEGER NOT NULL,
UNIQUE (room_id, server_name, peek_id)
);
`
const insertOutboundPeekSQL = "" +
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
const selectOutboundPeekSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
const selectOutboundPeeksSQL = "" +
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
const renewOutboundPeekSQL = "" +
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
const deleteOutboundPeekSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
const deleteOutboundPeeksSQL = "" +
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
type outboundPeeksStatements struct {
db *sql.DB
insertOutboundPeekStmt *sql.Stmt
selectOutboundPeekStmt *sql.Stmt
selectOutboundPeeksStmt *sql.Stmt
renewOutboundPeekStmt *sql.Stmt
deleteOutboundPeekStmt *sql.Stmt
deleteOutboundPeeksStmt *sql.Stmt
}
func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
s = &outboundPeeksStatements{
db: db,
}
_, err = db.Exec(outboundPeeksSchema)
if err != nil {
return
}
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
return
}
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
return
}
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
return
}
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
return
}
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
return
}
return
}
func (s *outboundPeeksStatements) InsertOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
return
}
func (s *outboundPeeksStatements) RenewOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
) (err error) {
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
return
}
func (s *outboundPeeksStatements) SelectOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (*types.OutboundPeek, error) {
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
outboundPeek := types.OutboundPeek{}
err := row.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
)
if err == sql.ErrNoRows {
return nil, nil
}
if err != nil {
return nil, err
}
return &outboundPeek, nil
}
func (s *outboundPeeksStatements) SelectOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (outboundPeeks []types.OutboundPeek, err error) {
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")
for rows.Next() {
outboundPeek := types.OutboundPeek{}
if err = rows.Scan(
&outboundPeek.RoomID,
&outboundPeek.ServerName,
&outboundPeek.PeekID,
&outboundPeek.CreationTimestamp,
&outboundPeek.RenewedTimestamp,
&outboundPeek.RenewalInterval,
); err != nil {
return
}
outboundPeeks = append(outboundPeeks, outboundPeek)
}
return outboundPeeks, rows.Err()
}
func (s *outboundPeeksStatements) DeleteOutboundPeek(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
return
}
func (s *outboundPeeksStatements) DeleteOutboundPeeks(
ctx context.Context, txn *sql.Tx, roomID string,
) (err error) {
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
return
}

View file

@ -0,0 +1,207 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queueEDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
-- The type of the event (informational).
edu_type TEXT NOT NULL,
-- The domain part of the user ID the EDU event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_edus_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
ON federationsender_queue_edus (json_nid, server_name);
`
const insertQueueEDUSQL = "" +
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueueEDUsSQL = "" +
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueEDUSQL = "" +
"SELECT json_nid FROM federationsender_queue_edus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueueEDUReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE json_nid = $1"
const selectQueueEDUCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_edus" +
" WHERE server_name = $1"
const selectQueueServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
type queueEDUsStatements struct {
db *sql.DB
insertQueueEDUStmt *sql.Stmt
selectQueueEDUStmt *sql.Stmt
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
selectQueueEDUCountStmt *sql.Stmt
selectQueueEDUServerNamesStmt *sql.Stmt
}
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
s = &queueEDUsStatements{
db: db,
}
_, err = db.Exec(queueEDUsSchema)
if err != nil {
return
}
if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
return
}
if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
return
}
if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
return
}
return
}
func (s *queueEDUsStatements) InsertQueueEDU(
ctx context.Context,
txn *sql.Tx,
eduType string,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
_, err := stmt.ExecContext(
ctx,
eduType, // the EDU type
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queueEDUsStatements) DeleteQueueEDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *queueEDUsStatements) SelectQueueEDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, nil
}
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
return -1, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, serverName)
}
return result, rows.Err()
}

View file

@ -0,0 +1,136 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
)
const queueJSONSchema = `
-- The queue_retry_json table contains event contents that
-- we failed to send.
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
-- The JSON NID. This allows the federationsender_queue_retry table to
-- cross-reference to find the JSON blob.
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
-- The JSON body. Text so that we preserve UTF-8.
json_body TEXT NOT NULL
);
`
const insertJSONSQL = "" +
"INSERT INTO federationsender_queue_json (json_body)" +
" VALUES ($1)"
const deleteJSONSQL = "" +
"DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
const selectJSONSQL = "" +
"SELECT json_nid, json_body FROM federationsender_queue_json" +
" WHERE json_nid IN ($1)"
type queueJSONStatements struct {
db *sql.DB
insertJSONStmt *sql.Stmt
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
s = &queueJSONStatements{
db: db,
}
_, err = db.Exec(queueJSONSchema)
if err != nil {
return
}
if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil {
return
}
return
}
func (s *queueJSONStatements) InsertQueueJSON(
ctx context.Context, txn *sql.Tx, json string,
) (lastid int64, err error) {
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
res, err := stmt.ExecContext(ctx, json)
if err != nil {
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
}
lastid, err = res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("res.LastInsertId: %w", err)
}
return
}
func (s *queueJSONStatements) DeleteQueueJSON(
ctx context.Context, txn *sql.Tx, nids []int64,
) error {
deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(nids))
for k, v := range nids {
iNIDs[k] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, iNIDs...)
return err
}
func (s *queueJSONStatements) SelectQueueJSON(
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
) (map[int64][]byte, error) {
selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
selectStmt, err := txn.Prepare(selectSQL)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
}
iNIDs := make([]interface{}, len(jsonNIDs))
for k, v := range jsonNIDs {
iNIDs[k] = v
}
blobs := map[int64][]byte{}
stmt := sqlutil.TxStmt(txn, selectStmt)
rows, err := stmt.QueryContext(ctx, iNIDs...)
if err != nil {
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
}
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
for rows.Next() {
var nid int64
var blob []byte
if err = rows.Scan(&nid, &blob); err != nil {
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
}
blobs[nid] = blob
}
return blobs, err
}

View file

@ -0,0 +1,235 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const queuePDUsSchema = `
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
-- The transaction ID that was generated before persisting the event.
transaction_id TEXT NOT NULL,
-- The domain part of the user ID the m.room.member event is for.
server_name TEXT NOT NULL,
-- The JSON NID from the federationsender_queue_pdus_json table.
json_nid BIGINT NOT NULL
);
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
ON federationsender_queue_pdus (json_nid, server_name);
`
const insertQueuePDUSQL = "" +
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
" VALUES ($1, $2, $3)"
const deleteQueuePDUsSQL = "" +
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
const selectQueueNextTransactionIDSQL = "" +
"SELECT transaction_id FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" ORDER BY transaction_id ASC" +
" LIMIT 1"
const selectQueuePDUsSQL = "" +
"SELECT json_nid FROM federationsender_queue_pdus" +
" WHERE server_name = $1" +
" LIMIT $2"
const selectQueuePDUsReferenceJSONCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE json_nid = $1"
const selectQueuePDUsCountSQL = "" +
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
" WHERE server_name = $1"
const selectQueuePDUsServerNamesSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
type queuePDUsStatements struct {
db *sql.DB
insertQueuePDUStmt *sql.Stmt
selectQueueNextTransactionIDStmt *sql.Stmt
selectQueuePDUsStmt *sql.Stmt
selectQueueReferenceJSONCountStmt *sql.Stmt
selectQueuePDUsCountStmt *sql.Stmt
selectQueueServerNamesStmt *sql.Stmt
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
}
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
s = &queuePDUsStatements{
db: db,
}
_, err = db.Exec(queuePDUsSchema)
if err != nil {
return
}
if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil {
return
}
//if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil {
// return
//}
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
return
}
if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
return
}
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
return
}
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
return
}
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
return
}
return
}
func (s *queuePDUsStatements) InsertQueuePDU(
ctx context.Context,
txn *sql.Tx,
transactionID gomatrixserverlib.TransactionID,
serverName gomatrixserverlib.ServerName,
nid int64,
) error {
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
_, err := stmt.ExecContext(
ctx,
transactionID, // the transaction ID that we initially attempted
serverName, // destination server name
nid, // JSON blob NID
)
return err
}
func (s *queuePDUsStatements) DeleteQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
jsonNIDs []int64,
) error {
deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
deleteStmt, err := txn.Prepare(deleteSQL)
if err != nil {
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
}
params := make([]interface{}, len(jsonNIDs)+1)
params[0] = serverName
for k, v := range jsonNIDs {
params[k+1] = v
}
stmt := sqlutil.TxStmt(txn, deleteStmt)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (gomatrixserverlib.TransactionID, error) {
var transactionID gomatrixserverlib.TransactionID
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
if err == sql.ErrNoRows {
return "", nil
}
return transactionID, err
}
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
ctx context.Context, txn *sql.Tx, jsonNID int64,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
if err == sql.ErrNoRows {
return -1, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUCount(
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
) (int64, error) {
var count int64
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
if err == sql.ErrNoRows {
// It's acceptable for there to be no rows referencing a given
// JSON NID but it's not an error condition. Just return as if
// there's a zero count.
return 0, nil
}
return count, err
}
func (s *queuePDUsStatements) SelectQueuePDUs(
ctx context.Context, txn *sql.Tx,
serverName gomatrixserverlib.ServerName,
limit int,
) ([]int64, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
rows, err := stmt.QueryContext(ctx, serverName, limit)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []int64
for rows.Next() {
var nid int64
if err = rows.Scan(&nid); err != nil {
return nil, err
}
result = append(result, nid)
}
return result, rows.Err()
}
func (s *queuePDUsStatements) SelectQueuePDUServerNames(
ctx context.Context, txn *sql.Tx,
) ([]gomatrixserverlib.ServerName, error) {
stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName gomatrixserverlib.ServerName
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, serverName)
}
return result, rows.Err()
}

View file

@ -0,0 +1,157 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib"
)
const serverSigningKeysSchema = `
-- A cache of signing keys downloaded from remote servers.
CREATE TABLE IF NOT EXISTS keydb_server_keys (
-- The name of the matrix server the key is for.
server_name TEXT NOT NULL,
-- The ID of the server key.
server_key_id TEXT NOT NULL,
-- Combined server name and key ID separated by the ASCII unit separator
-- to make it easier to run bulk queries.
server_name_and_key_id TEXT NOT NULL,
-- When the key is valid until as a millisecond timestamp.
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
valid_until_ts BIGINT NOT NULL,
-- When the key expired as a millisecond timestamp.
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
expired_ts BIGINT NOT NULL,
-- The base64-encoded public key.
server_key TEXT NOT NULL,
UNIQUE (server_name, server_key_id)
);
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
`
const bulkSelectServerSigningKeysSQL = "" +
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
" server_key FROM keydb_server_keys" +
" WHERE server_name_and_key_id IN ($1)"
const upsertServerSigningKeysSQL = "" +
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
" VALUES ($1, $2, $3, $4, $5, $6)" +
" ON CONFLICT (server_name, server_key_id)" +
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
type serverSigningKeyStatements struct {
db *sql.DB
bulkSelectServerKeysStmt *sql.Stmt
upsertServerKeysStmt *sql.Stmt
}
func NewSQLiteServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, err error) {
s = &serverSigningKeyStatements{
db: db,
}
_, err = db.Exec(serverSigningKeysSchema)
if err != nil {
return
}
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil {
return
}
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil {
return
}
return s, nil
}
func (s *serverSigningKeyStatements) BulkSelectServerKeys(
ctx context.Context, txn *sql.Tx,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
}
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v
}
err := sqlutil.RunLimitedVariablesQuery(
ctx, bulkSelectServerSigningKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
func(rows *sql.Rows) error {
for rows.Next() {
var serverName string
var keyID string
var key string
var validUntilTS int64
var expiredTS int64
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID),
}
vk := gomatrixserverlib.VerifyKey{}
err := vk.Key.Decode(key)
if err != nil {
return fmt.Errorf("bulkSelectServerKeys: %v", err)
}
results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk,
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
}
}
return nil
},
)
if err != nil {
return nil, err
}
return results, nil
}
func (s *serverSigningKeyStatements) UpsertServerKeys(
ctx context.Context, txn *sql.Tx,
request gomatrixserverlib.PublicKeyLookupRequest,
key gomatrixserverlib.PublicKeyLookupResult,
) error {
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
_, err := stmt.ExecContext(
ctx,
string(request.ServerName),
string(request.KeyID),
nameAndKeyID(request),
key.ValidUntilTS,
key.ExpiredTS,
key.Key.Encode(),
)
return err
}
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
return string(request.ServerName) + "\x1F" + string(request.KeyID)
}

View file

@ -0,0 +1,108 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"database/sql"
"github.com/matrix-org/dendrite/federationapi/storage/shared"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config"
)
// Database stores information needed by the federation sender
type Database struct {
shared.Database
sqlutil.PartitionOffsetStatements
db *sql.DB
writer sqlutil.Writer
}
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
var d Database
var err error
if d.db, err = sqlutil.Open(dbProperties); err != nil {
return nil, err
}
d.writer = sqlutil.NewExclusiveWriter()
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
if err != nil {
return nil, err
}
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
if err != nil {
return nil, err
}
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
if err != nil {
return nil, err
}
queueJSON, err := NewSQLiteQueueJSONTable(d.db)
if err != nil {
return nil, err
}
blacklist, err := NewSQLiteBlacklistTable(d.db)
if err != nil {
return nil, err
}
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
if err != nil {
return nil, err
}
inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db)
if err != nil {
return nil, err
}
notaryKeys, err := NewSQLiteNotaryServerKeysTable(d.db)
if err != nil {
return nil, err
}
notaryKeysMetadata, err := NewSQLiteNotaryServerKeysMetadataTable(d.db)
if err != nil {
return nil, err
}
serverSigningKeys, err := NewSQLiteServerSigningKeysTable(d.db)
if err != nil {
return nil, err
}
m := sqlutil.NewMigrations()
deltas.LoadRemoveRoomsTable(m)
if err = m.RunDeltas(d.db, dbProperties); err != nil {
return nil, err
}
d.Database = shared.Database{
DB: d.db,
Cache: cache,
Writer: d.writer,
FederationJoinedHosts: joinedHosts,
FederationQueuePDUs: queuePDUs,
FederationQueueEDUs: queueEDUs,
FederationQueueJSON: queueJSON,
FederationBlacklist: blacklist,
FederationOutboundPeeks: outboundPeeks,
FederationInboundPeeks: inboundPeeks,
NotaryServerKeysJSON: notaryKeys,
NotaryServerKeysMetadata: notaryKeysMetadata,
ServerSigningKeys: serverSigningKeys,
}
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
return nil, err
}
return &d, nil
}

View file

@ -0,0 +1,39 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//go:build !wasm
// +build !wasm
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, cache)
case dbProperties.ConnectionString.IsPostgres():
return postgres.NewDatabase(dbProperties, cache)
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,35 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package storage
import (
"fmt"
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/setup/config"
)
// NewDatabase opens a new database
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
switch {
case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, cache)
case dbProperties.ConnectionString.IsPostgres():
return nil, fmt.Errorf("can't use Postgres implementation")
default:
return nil, fmt.Errorf("unexpected database type")
}
}

View file

@ -0,0 +1,111 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/federationapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
type NotaryID int64
type FederationQueuePDUs interface {
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
}
type FederationQueueEDUs interface {
InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
}
type FederationQueueJSON interface {
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
}
type FederationJoinedHosts interface {
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
}
type FederationBlacklist interface {
InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)
DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
}
type FederationOutboundPeeks interface {
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error)
SelectOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (outboundPeeks []types.OutboundPeek, err error)
DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error)
DeleteOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error)
}
type FederationInboundPeeks interface {
InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error)
SelectInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (inboundPeeks []types.InboundPeek, err error)
DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error)
DeleteInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error)
}
// FederationNotaryServerKeysJSON contains the byte-for-byte responses from servers which contain their keys and is signed by them.
type FederationNotaryServerKeysJSON interface {
// InsertJSONResponse inserts a new response JSON. Useless on its own, needs querying via FederationNotaryServerKeysMetadata
// `validUntil` should be the value of `valid_until_ts` with the 7-day check applied from:
// "Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
// This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of time
// without a way for the homeserver owner to revoke it.""
InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp) (NotaryID, error)
}
// FederationNotaryServerKeysMetadata persists the metadata for FederationNotaryServerKeysJSON
type FederationNotaryServerKeysMetadata interface {
// UpsertKey updates or inserts a (server_name, key_id) tuple, pointing it via NotaryID at the the response which has the longest valid_until_ts
// `newNotaryID` and `newValidUntil` should be the notary ID / valid_until which has this (server_name, key_id) tuple already, e.g one you just inserted.
UpsertKey(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil gomatrixserverlib.Timestamp) (NotaryID, error)
// SelectKeys returns the signed JSON objects which contain the given key IDs. This will be at most the length of `keyIDs` and at least 1 (assuming
// the keys exist in the first place). If `keyIDs` is empty, the signed JSON object with the longest valid_until_ts will be returned.
SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
// DeleteOldJSONResponses removes all responses which are not referenced in FederationNotaryServerKeysMetadata
DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error
}
type FederationServerSigningKeys interface {
BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error)
UpsertServerKeys(ctx context.Context, txn *sql.Tx, request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult) error
}

View file

@ -0,0 +1,53 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package types
import (
"github.com/matrix-org/gomatrixserverlib"
)
// A JoinedHost is a server that is joined to a matrix room.
type JoinedHost struct {
// The MemberEventID of a m.room.member join event.
MemberEventID string
// The domain part of the state key of the m.room.member join event
ServerName gomatrixserverlib.ServerName
}
type ServerNames []gomatrixserverlib.ServerName
func (s ServerNames) Len() int { return len(s) }
func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] }
// tracks peeks we're performing on another server over federation
type OutboundPeek struct {
PeekID string
RoomID string
ServerName gomatrixserverlib.ServerName
CreationTimestamp int64
RenewedTimestamp int64
RenewalInterval int64
}
// tracks peeks other servers are performing on us over federation
type InboundPeek struct {
PeekID string
RoomID string
ServerName gomatrixserverlib.ServerName
CreationTimestamp int64
RenewedTimestamp int64
RenewalInterval int64
}