mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-29 12:42:46 +00:00
Merge federationapi
, federationsender
, signingkeyserver
components (#2055)
* Initial federation sender -> federation API refactoring * Move base into own package, avoids import cycle * Fix build errors * Fix tests * Add signing key server tables * Try to fold signing key server into federation API * Fix dendritejs builds * Update embedded interfaces * Fix panic, fix lint error * Update configs, docker * Rename some things * Reuse same keyring on the implementing side * Fix federation tests, `NewBaseDendrite` can accept freeform options * Fix build * Update create_db, configs * Name tables back * Don't rename federationsender consumer for now
This commit is contained in:
parent
6e93531e94
commit
ec716793eb
136 changed files with 1211 additions and 1786 deletions
208
federationapi/api/api.go
Normal file
208
federationapi/api/api.go
Normal file
|
@ -0,0 +1,208 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// FederationClient is a subset of gomatrixserverlib.FederationClient functions which the fedsender
|
||||
// implements as proxy calls, with built-in backoff/retries/etc. Errors returned from functions in
|
||||
// this interface are of type FederationClientError
|
||||
type FederationClient interface {
|
||||
gomatrixserverlib.BackfillClient
|
||||
gomatrixserverlib.FederatedStateClient
|
||||
GetUserDevices(ctx context.Context, s gomatrixserverlib.ServerName, userID string) (res gomatrixserverlib.RespUserDevices, err error)
|
||||
ClaimKeys(ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string) (res gomatrixserverlib.RespClaimKeys, err error)
|
||||
QueryKeys(ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string) (res gomatrixserverlib.RespQueryKeys, err error)
|
||||
GetEvent(ctx context.Context, s gomatrixserverlib.ServerName, eventID string) (res gomatrixserverlib.Transaction, err error)
|
||||
MSC2836EventRelationships(ctx context.Context, dst gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest, roomVersion gomatrixserverlib.RoomVersion) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error)
|
||||
MSC2946Spaces(ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest) (res gomatrixserverlib.MSC2946SpacesResponse, err error)
|
||||
LookupServerKeys(ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) ([]gomatrixserverlib.ServerKeys, error)
|
||||
}
|
||||
|
||||
// FederationClientError is returned from FederationClient methods in the event of a problem.
|
||||
type FederationClientError struct {
|
||||
Err string
|
||||
RetryAfter time.Duration
|
||||
Blacklisted bool
|
||||
}
|
||||
|
||||
func (e *FederationClientError) Error() string {
|
||||
return fmt.Sprintf("%s - (retry_after=%s, blacklisted=%v)", e.Err, e.RetryAfter.String(), e.Blacklisted)
|
||||
}
|
||||
|
||||
// FederationInternalAPI is used to query information from the federation sender.
|
||||
type FederationInternalAPI interface {
|
||||
FederationClient
|
||||
gomatrixserverlib.KeyDatabase
|
||||
|
||||
KeyRing() *gomatrixserverlib.KeyRing
|
||||
|
||||
QueryServerKeys(ctx context.Context, request *QueryServerKeysRequest, response *QueryServerKeysResponse) error
|
||||
|
||||
// PerformDirectoryLookup looks up a remote room ID from a room alias.
|
||||
PerformDirectoryLookup(
|
||||
ctx context.Context,
|
||||
request *PerformDirectoryLookupRequest,
|
||||
response *PerformDirectoryLookupResponse,
|
||||
) error
|
||||
// Query the server names of the joined hosts in a room.
|
||||
// Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice
|
||||
// containing only the server names (without information for membership events).
|
||||
// The response will include this server if they are joined to the room.
|
||||
QueryJoinedHostServerNamesInRoom(
|
||||
ctx context.Context,
|
||||
request *QueryJoinedHostServerNamesInRoomRequest,
|
||||
response *QueryJoinedHostServerNamesInRoomResponse,
|
||||
) error
|
||||
// Handle an instruction to make_join & send_join with a remote server.
|
||||
PerformJoin(
|
||||
ctx context.Context,
|
||||
request *PerformJoinRequest,
|
||||
response *PerformJoinResponse,
|
||||
)
|
||||
// Handle an instruction to peek a room on a remote server.
|
||||
PerformOutboundPeek(
|
||||
ctx context.Context,
|
||||
request *PerformOutboundPeekRequest,
|
||||
response *PerformOutboundPeekResponse,
|
||||
) error
|
||||
// Handle an instruction to make_leave & send_leave with a remote server.
|
||||
PerformLeave(
|
||||
ctx context.Context,
|
||||
request *PerformLeaveRequest,
|
||||
response *PerformLeaveResponse,
|
||||
) error
|
||||
// Handle sending an invite to a remote server.
|
||||
PerformInvite(
|
||||
ctx context.Context,
|
||||
request *PerformInviteRequest,
|
||||
response *PerformInviteResponse,
|
||||
) error
|
||||
// Notifies the federation sender that these servers may be online and to retry sending messages.
|
||||
PerformServersAlive(
|
||||
ctx context.Context,
|
||||
request *PerformServersAliveRequest,
|
||||
response *PerformServersAliveResponse,
|
||||
) error
|
||||
// Broadcasts an EDU to all servers in rooms we are joined to.
|
||||
PerformBroadcastEDU(
|
||||
ctx context.Context,
|
||||
request *PerformBroadcastEDURequest,
|
||||
response *PerformBroadcastEDUResponse,
|
||||
) error
|
||||
}
|
||||
|
||||
type QueryServerKeysRequest struct {
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
KeyIDToCriteria map[gomatrixserverlib.KeyID]gomatrixserverlib.PublicKeyNotaryQueryCriteria
|
||||
}
|
||||
|
||||
func (q *QueryServerKeysRequest) KeyIDs() []gomatrixserverlib.KeyID {
|
||||
kids := make([]gomatrixserverlib.KeyID, len(q.KeyIDToCriteria))
|
||||
i := 0
|
||||
for keyID := range q.KeyIDToCriteria {
|
||||
kids[i] = keyID
|
||||
i++
|
||||
}
|
||||
return kids
|
||||
}
|
||||
|
||||
type QueryServerKeysResponse struct {
|
||||
ServerKeys []gomatrixserverlib.ServerKeys
|
||||
}
|
||||
|
||||
type QueryPublicKeysRequest struct {
|
||||
Requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp `json:"requests"`
|
||||
}
|
||||
|
||||
type QueryPublicKeysResponse struct {
|
||||
Results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"results"`
|
||||
}
|
||||
|
||||
type PerformDirectoryLookupRequest struct {
|
||||
RoomAlias string `json:"room_alias"`
|
||||
ServerName gomatrixserverlib.ServerName `json:"server_name"`
|
||||
}
|
||||
|
||||
type PerformDirectoryLookupResponse struct {
|
||||
RoomID string `json:"room_id"`
|
||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||
}
|
||||
|
||||
type PerformJoinRequest struct {
|
||||
RoomID string `json:"room_id"`
|
||||
UserID string `json:"user_id"`
|
||||
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
||||
ServerNames types.ServerNames `json:"server_names"`
|
||||
Content map[string]interface{} `json:"content"`
|
||||
}
|
||||
|
||||
type PerformJoinResponse struct {
|
||||
JoinedVia gomatrixserverlib.ServerName
|
||||
LastError *gomatrix.HTTPError
|
||||
}
|
||||
|
||||
type PerformOutboundPeekRequest struct {
|
||||
RoomID string `json:"room_id"`
|
||||
// The sorted list of servers to try. Servers will be tried sequentially, after de-duplication.
|
||||
ServerNames types.ServerNames `json:"server_names"`
|
||||
}
|
||||
|
||||
type PerformOutboundPeekResponse struct {
|
||||
LastError *gomatrix.HTTPError
|
||||
}
|
||||
|
||||
type PerformLeaveRequest struct {
|
||||
RoomID string `json:"room_id"`
|
||||
UserID string `json:"user_id"`
|
||||
ServerNames types.ServerNames `json:"server_names"`
|
||||
}
|
||||
|
||||
type PerformLeaveResponse struct {
|
||||
}
|
||||
|
||||
type PerformInviteRequest struct {
|
||||
RoomVersion gomatrixserverlib.RoomVersion `json:"room_version"`
|
||||
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
|
||||
InviteRoomState []gomatrixserverlib.InviteV2StrippedState `json:"invite_room_state"`
|
||||
}
|
||||
|
||||
type PerformInviteResponse struct {
|
||||
Event *gomatrixserverlib.HeaderedEvent `json:"event"`
|
||||
}
|
||||
|
||||
type PerformServersAliveRequest struct {
|
||||
Servers []gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type PerformServersAliveResponse struct {
|
||||
}
|
||||
|
||||
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
|
||||
type QueryJoinedHostServerNamesInRoomRequest struct {
|
||||
RoomID string `json:"room_id"`
|
||||
}
|
||||
|
||||
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
|
||||
type QueryJoinedHostServerNamesInRoomResponse struct {
|
||||
ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
|
||||
}
|
||||
|
||||
type PerformBroadcastEDURequest struct {
|
||||
}
|
||||
|
||||
type PerformBroadcastEDUResponse struct {
|
||||
}
|
||||
|
||||
type InputPublicKeysRequest struct {
|
||||
Keys map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult `json:"keys"`
|
||||
}
|
||||
|
||||
type InputPublicKeysResponse struct {
|
||||
}
|
249
federationapi/consumers/eduserver.go
Normal file
249
federationapi/consumers/eduserver.go
Normal file
|
@ -0,0 +1,249 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputEDUConsumer consumes events that originate in EDU server.
|
||||
type OutputEDUConsumer struct {
|
||||
typingConsumer *internal.ContinualConsumer
|
||||
sendToDeviceConsumer *internal.ContinualConsumer
|
||||
receiptConsumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
TypingTopic string
|
||||
SendToDeviceTopic string
|
||||
}
|
||||
|
||||
// NewOutputEDUConsumer creates a new OutputEDUConsumer. Call Start() to begin consuming from EDU servers.
|
||||
func NewOutputEDUConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.FederationAPI,
|
||||
kafkaConsumer sarama.Consumer,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
) *OutputEDUConsumer {
|
||||
c := &OutputEDUConsumer{
|
||||
typingConsumer: &internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "eduserver/typing",
|
||||
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
},
|
||||
sendToDeviceConsumer: &internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "eduserver/sendtodevice",
|
||||
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
},
|
||||
receiptConsumer: &internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "eduserver/receipt",
|
||||
Topic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputReceiptEvent),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
},
|
||||
queues: queues,
|
||||
db: store,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
TypingTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputTypingEvent),
|
||||
SendToDeviceTopic: cfg.Matrix.Kafka.TopicFor(config.TopicOutputSendToDeviceEvent),
|
||||
}
|
||||
c.typingConsumer.ProcessMessage = c.onTypingEvent
|
||||
c.sendToDeviceConsumer.ProcessMessage = c.onSendToDeviceEvent
|
||||
c.receiptConsumer.ProcessMessage = c.onReceiptEvent
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Start consuming from EDU servers
|
||||
func (t *OutputEDUConsumer) Start() error {
|
||||
if err := t.typingConsumer.Start(); err != nil {
|
||||
return fmt.Errorf("t.typingConsumer.Start: %w", err)
|
||||
}
|
||||
if err := t.sendToDeviceConsumer.Start(); err != nil {
|
||||
return fmt.Errorf("t.sendToDeviceConsumer.Start: %w", err)
|
||||
}
|
||||
if err := t.receiptConsumer.Start(); err != nil {
|
||||
return fmt.Errorf("t.receiptConsumer.Start: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// onSendToDeviceEvent is called in response to a message received on the
|
||||
// send-to-device events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onSendToDeviceEvent(msg *sarama.ConsumerMessage) error {
|
||||
// Extract the send-to-device event from msg.
|
||||
var ote api.OutputSendToDeviceEvent
|
||||
if err := json.Unmarshal(msg.Value, &ote); err != nil {
|
||||
log.WithError(err).Errorf("eduserver output log: message parse failed (expected send-to-device)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// only send send-to-device events which originated from us
|
||||
_, originServerName, err := gomatrixserverlib.SplitID('@', ote.Sender)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("user_id", ote.Sender).Error("Failed to extract domain from send-to-device sender")
|
||||
return nil
|
||||
}
|
||||
if originServerName != t.ServerName {
|
||||
log.WithField("other_server", originServerName).Info("Suppressing send-to-device: originated elsewhere")
|
||||
return nil
|
||||
}
|
||||
|
||||
_, destServerName, err := gomatrixserverlib.SplitID('@', ote.UserID)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("user_id", ote.UserID).Error("Failed to extract domain from send-to-device destination")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pack the EDU and marshal it
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: gomatrixserverlib.MDirectToDevice,
|
||||
Origin: string(t.ServerName),
|
||||
}
|
||||
tdm := gomatrixserverlib.ToDeviceMessage{
|
||||
Sender: ote.Sender,
|
||||
Type: ote.Type,
|
||||
MessageID: util.RandomString(32),
|
||||
Messages: map[string]map[string]json.RawMessage{
|
||||
ote.UserID: {
|
||||
ote.DeviceID: ote.Content,
|
||||
},
|
||||
},
|
||||
}
|
||||
if edu.Content, err = json.Marshal(tdm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Sending send-to-device message into %q destination queue", destServerName)
|
||||
return t.queues.SendEDU(edu, t.ServerName, []gomatrixserverlib.ServerName{destServerName})
|
||||
}
|
||||
|
||||
// onTypingEvent is called in response to a message received on the typing
|
||||
// events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onTypingEvent(msg *sarama.ConsumerMessage) error {
|
||||
// Extract the typing event from msg.
|
||||
var ote api.OutputTypingEvent
|
||||
if err := json.Unmarshal(msg.Value, &ote); err != nil {
|
||||
// Skip this msg but continue processing messages.
|
||||
log.WithError(err).Errorf("eduserver output log: message parse failed (expected typing)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// only send typing events which originated from us
|
||||
_, typingServerName, err := gomatrixserverlib.SplitID('@', ote.Event.UserID)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("user_id", ote.Event.UserID).Error("Failed to extract domain from typing sender")
|
||||
return nil
|
||||
}
|
||||
if typingServerName != t.ServerName {
|
||||
log.WithField("other_server", typingServerName).Info("Suppressing typing notif: originated elsewhere")
|
||||
return nil
|
||||
}
|
||||
|
||||
joined, err := t.db.GetJoinedHosts(context.TODO(), ote.Event.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
names := make([]gomatrixserverlib.ServerName, len(joined))
|
||||
for i := range joined {
|
||||
names[i] = joined[i].ServerName
|
||||
}
|
||||
|
||||
edu := &gomatrixserverlib.EDU{Type: ote.Event.Type}
|
||||
if edu.Content, err = json.Marshal(map[string]interface{}{
|
||||
"room_id": ote.Event.RoomID,
|
||||
"user_id": ote.Event.UserID,
|
||||
"typing": ote.Event.Typing,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.queues.SendEDU(edu, t.ServerName, names)
|
||||
}
|
||||
|
||||
// onReceiptEvent is called in response to a message received on the receipt
|
||||
// events topic from the EDU server.
|
||||
func (t *OutputEDUConsumer) onReceiptEvent(msg *sarama.ConsumerMessage) error {
|
||||
// Extract the typing event from msg.
|
||||
var receipt api.OutputReceiptEvent
|
||||
if err := json.Unmarshal(msg.Value, &receipt); err != nil {
|
||||
// Skip this msg but continue processing messages.
|
||||
log.WithError(err).Errorf("eduserver output log: message parse failed (expected receipt)")
|
||||
return nil
|
||||
}
|
||||
|
||||
// only send receipt events which originated from us
|
||||
_, receiptServerName, err := gomatrixserverlib.SplitID('@', receipt.UserID)
|
||||
if err != nil {
|
||||
log.WithError(err).WithField("user_id", receipt.UserID).Error("Failed to extract domain from receipt sender")
|
||||
return nil
|
||||
}
|
||||
if receiptServerName != t.ServerName {
|
||||
return nil // don't log, very spammy as it logs for each remote receipt
|
||||
}
|
||||
|
||||
joined, err := t.db.GetJoinedHosts(context.TODO(), receipt.RoomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
names := make([]gomatrixserverlib.ServerName, len(joined))
|
||||
for i := range joined {
|
||||
names[i] = joined[i].ServerName
|
||||
}
|
||||
|
||||
content := map[string]api.FederationReceiptMRead{}
|
||||
content[receipt.RoomID] = api.FederationReceiptMRead{
|
||||
User: map[string]api.FederationReceiptData{
|
||||
receipt.UserID: {
|
||||
Data: api.ReceiptTS{
|
||||
TS: receipt.Timestamp,
|
||||
},
|
||||
EventIDs: []string{receipt.EventID},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: gomatrixserverlib.MReceipt,
|
||||
Origin: string(t.ServerName),
|
||||
}
|
||||
if edu.Content, err = json.Marshal(content); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return t.queues.SendEDU(edu, t.ServerName, names)
|
||||
}
|
202
federationapi/consumers/keychange.go
Normal file
202
federationapi/consumers/keychange.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// KeyChangeConsumer consumes events that originate in key server.
|
||||
type KeyChangeConsumer struct {
|
||||
consumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
serverName gomatrixserverlib.ServerName
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
}
|
||||
|
||||
// NewKeyChangeConsumer creates a new KeyChangeConsumer. Call Start() to begin consuming from key servers.
|
||||
func NewKeyChangeConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.KeyServer,
|
||||
kafkaConsumer sarama.Consumer,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
) *KeyChangeConsumer {
|
||||
c := &KeyChangeConsumer{
|
||||
consumer: &internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "federationapi/keychange",
|
||||
Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputKeyChangeEvent)),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
},
|
||||
queues: queues,
|
||||
db: store,
|
||||
serverName: cfg.Matrix.ServerName,
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
c.consumer.ProcessMessage = c.onMessage
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
// Start consuming from key servers
|
||||
func (t *KeyChangeConsumer) Start() error {
|
||||
if err := t.consumer.Start(); err != nil {
|
||||
return fmt.Errorf("t.consumer.Start: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// onMessage is called in response to a message received on the
|
||||
// key change events topic from the key server.
|
||||
func (t *KeyChangeConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
var m api.DeviceMessage
|
||||
if err := json.Unmarshal(msg.Value, &m); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to read device message from key change topic")
|
||||
return nil
|
||||
}
|
||||
if m.DeviceKeys == nil && m.OutputCrossSigningKeyUpdate == nil {
|
||||
// This probably shouldn't happen but stops us from panicking if we come
|
||||
// across an update that doesn't satisfy either types.
|
||||
return nil
|
||||
}
|
||||
switch m.Type {
|
||||
case api.TypeCrossSigningUpdate:
|
||||
return t.onCrossSigningMessage(m)
|
||||
case api.TypeDeviceKeyUpdate:
|
||||
fallthrough
|
||||
default:
|
||||
return t.onDeviceKeyMessage(m)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) error {
|
||||
logger := logrus.WithField("user_id", m.UserID)
|
||||
|
||||
// only send key change events which originated from us
|
||||
_, originServerName, err := gomatrixserverlib.SplitID('@', m.UserID)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("Failed to extract domain from key change event")
|
||||
return nil
|
||||
}
|
||||
if originServerName != t.serverName {
|
||||
return nil
|
||||
}
|
||||
|
||||
var queryRes roomserverAPI.QueryRoomsForUserResponse
|
||||
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
|
||||
UserID: m.UserID,
|
||||
WantMembership: "join",
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to calculate joined rooms for user")
|
||||
return nil
|
||||
}
|
||||
// send this key change to all servers who share rooms with this user.
|
||||
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pack the EDU and marshal it
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: gomatrixserverlib.MDeviceListUpdate,
|
||||
Origin: string(t.serverName),
|
||||
}
|
||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||
UserID: m.UserID,
|
||||
DeviceID: m.DeviceID,
|
||||
DeviceDisplayName: m.DisplayName,
|
||||
StreamID: m.StreamID,
|
||||
PrevID: prevID(m.StreamID),
|
||||
Deleted: len(m.KeyJSON) == 0,
|
||||
Keys: m.KeyJSON,
|
||||
}
|
||||
if edu.Content, err = json.Marshal(event); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logrus.Infof("Sending device list update message to %q", destinations)
|
||||
return t.queues.SendEDU(edu, t.serverName, destinations)
|
||||
}
|
||||
|
||||
func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) error {
|
||||
output := m.CrossSigningKeyUpdate
|
||||
_, host, err := gomatrixserverlib.SplitID('@', output.UserID)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("fedsender key change consumer: user ID parse failure")
|
||||
return nil
|
||||
}
|
||||
if host != gomatrixserverlib.ServerName(t.serverName) {
|
||||
// Ignore any messages that didn't originate locally, otherwise we'll
|
||||
// end up parroting information we received from other servers.
|
||||
return nil
|
||||
}
|
||||
logger := logrus.WithField("user_id", output.UserID)
|
||||
|
||||
var queryRes roomserverAPI.QueryRoomsForUserResponse
|
||||
err = t.rsAPI.QueryRoomsForUser(context.Background(), &roomserverAPI.QueryRoomsForUserRequest{
|
||||
UserID: output.UserID,
|
||||
WantMembership: "join",
|
||||
}, &queryRes)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined rooms for user")
|
||||
return nil
|
||||
}
|
||||
// send this key change to all servers who share rooms with this user.
|
||||
destinations, err := t.db.GetJoinedHostsForRooms(context.Background(), queryRes.RoomIDs)
|
||||
if err != nil {
|
||||
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pack the EDU and marshal it
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: eduserverAPI.MSigningKeyUpdate,
|
||||
Origin: string(t.serverName),
|
||||
}
|
||||
if edu.Content, err = json.Marshal(output); err != nil {
|
||||
logger.WithError(err).Error("fedsender key change consumer: failed to marshal output, dropping")
|
||||
return nil
|
||||
}
|
||||
|
||||
logger.Infof("Sending cross-signing update message to %q", destinations)
|
||||
return t.queues.SendEDU(edu, t.serverName, destinations)
|
||||
}
|
||||
|
||||
func prevID(streamID int) []int {
|
||||
if streamID <= 1 {
|
||||
return nil
|
||||
}
|
||||
return []int{streamID - 1}
|
||||
}
|
407
federationapi/consumers/roomserver.go
Normal file
407
federationapi/consumers/roomserver.go
Normal file
|
@ -0,0 +1,407 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// OutputRoomEventConsumer consumes events that originated in the room server.
|
||||
type OutputRoomEventConsumer struct {
|
||||
cfg *config.FederationAPI
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
rsConsumer *internal.ContinualConsumer
|
||||
db storage.Database
|
||||
queues *queue.OutgoingQueues
|
||||
}
|
||||
|
||||
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
|
||||
func NewOutputRoomEventConsumer(
|
||||
process *process.ProcessContext,
|
||||
cfg *config.FederationAPI,
|
||||
kafkaConsumer sarama.Consumer,
|
||||
queues *queue.OutgoingQueues,
|
||||
store storage.Database,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
) *OutputRoomEventConsumer {
|
||||
consumer := internal.ContinualConsumer{
|
||||
Process: process,
|
||||
ComponentName: "federationapi/roomserver",
|
||||
Topic: string(cfg.Matrix.Kafka.TopicFor(config.TopicOutputRoomEvent)),
|
||||
Consumer: kafkaConsumer,
|
||||
PartitionStore: store,
|
||||
}
|
||||
s := &OutputRoomEventConsumer{
|
||||
cfg: cfg,
|
||||
rsConsumer: &consumer,
|
||||
db: store,
|
||||
queues: queues,
|
||||
rsAPI: rsAPI,
|
||||
}
|
||||
consumer.ProcessMessage = s.onMessage
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start consuming from room servers
|
||||
func (s *OutputRoomEventConsumer) Start() error {
|
||||
return s.rsConsumer.Start()
|
||||
}
|
||||
|
||||
// onMessage is called when the federation server receives a new event from the room server output log.
|
||||
// It is unsafe to call this with messages for the same room in multiple gorountines
|
||||
// because updates it will likely fail with a types.EventIDMismatchError when it
|
||||
// realises that it cannot update the room state using the deltas.
|
||||
func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
|
||||
// Parse out the event JSON
|
||||
var output api.OutputEvent
|
||||
if err := json.Unmarshal(msg.Value, &output); err != nil {
|
||||
// If the message was invalid, log it and move on to the next message in the stream
|
||||
log.WithError(err).Errorf("roomserver output log: message parse failure")
|
||||
return nil
|
||||
}
|
||||
|
||||
switch output.Type {
|
||||
case api.OutputTypeNewRoomEvent:
|
||||
ev := output.NewRoomEvent.Event
|
||||
|
||||
if output.NewRoomEvent.RewritesState {
|
||||
if err := s.db.PurgeRoomState(context.TODO(), ev.RoomID()); err != nil {
|
||||
return fmt.Errorf("s.db.PurgeRoom: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.processMessage(*output.NewRoomEvent); err != nil {
|
||||
switch err.(type) {
|
||||
case *queue.ErrorFederationDisabled:
|
||||
log.WithField("error", output.Type).Info(
|
||||
err.Error(),
|
||||
)
|
||||
default:
|
||||
// panic rather than continue with an inconsistent database
|
||||
log.WithFields(log.Fields{
|
||||
"event_id": ev.EventID(),
|
||||
"event": string(ev.JSON()),
|
||||
"add": output.NewRoomEvent.AddsStateEventIDs,
|
||||
"del": output.NewRoomEvent.RemovesStateEventIDs,
|
||||
log.ErrorKey: err,
|
||||
}).Panicf("roomserver output log: write room event failure")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case api.OutputTypeNewInboundPeek:
|
||||
if err := s.processInboundPeek(*output.NewInboundPeek); err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"event": output.NewInboundPeek,
|
||||
log.ErrorKey: err,
|
||||
}).Panicf("roomserver output log: remote peek event failure")
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
log.WithField("type", output.Type).Debug(
|
||||
"roomserver output log: ignoring unknown output type",
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// processInboundPeek starts tracking a new federated inbound peek (replacing the existing one if any)
|
||||
// causing the federationapi to start sending messages to the peeking server
|
||||
func (s *OutputRoomEventConsumer) processInboundPeek(orp api.OutputNewInboundPeek) error {
|
||||
|
||||
// FIXME: there's a race here - we should start /sending new peeked events
|
||||
// atomically after the orp.LatestEventID to ensure there are no gaps between
|
||||
// the peek beginning and the send stream beginning.
|
||||
//
|
||||
// We probably need to track orp.LatestEventID on the inbound peek, but it's
|
||||
// unclear how we then use that to prevent the race when we start the send
|
||||
// stream.
|
||||
//
|
||||
// This is making the tests flakey.
|
||||
|
||||
return s.db.AddInboundPeek(context.TODO(), orp.ServerName, orp.RoomID, orp.PeekID, orp.RenewalInterval)
|
||||
}
|
||||
|
||||
// processMessage updates the list of currently joined hosts in the room
|
||||
// and then sends the event to the hosts that were joined before the event.
|
||||
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
|
||||
addsJoinedHosts, err := joinedHostsFromEvents(gomatrixserverlib.UnwrapEventHeaders(ore.AddsState()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Update our copy of the current state.
|
||||
// We keep a copy of the current state because the state at each event is
|
||||
// expressed as a delta against the current state.
|
||||
// TODO(#290): handle EventIDMismatchError and recover the current state by
|
||||
// talking to the roomserver
|
||||
oldJoinedHosts, err := s.db.UpdateRoom(
|
||||
context.TODO(),
|
||||
ore.Event.RoomID(),
|
||||
ore.LastSentEventID,
|
||||
ore.Event.EventID(),
|
||||
addsJoinedHosts,
|
||||
ore.RemovesStateEventIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if oldJoinedHosts == nil {
|
||||
// This means that there is nothing to update as this is a duplicate
|
||||
// message.
|
||||
// This can happen if dendrite crashed between reading the message and
|
||||
// persisting the stream position.
|
||||
return nil
|
||||
}
|
||||
|
||||
if ore.SendAsServer == api.DoNotSendToOtherServers {
|
||||
// Ignore event that we don't need to send anywhere.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Work out which hosts were joined at the event itself.
|
||||
joinedHostsAtEvent, err := s.joinedHostsAtEvent(ore, oldJoinedHosts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: do housekeeping to evict unrenewed peeking hosts
|
||||
|
||||
// TODO: implement query to let the fedapi check whether a given peek is live or not
|
||||
|
||||
// Send the event.
|
||||
return s.queues.SendEvent(
|
||||
ore.Event, gomatrixserverlib.ServerName(ore.SendAsServer), joinedHostsAtEvent,
|
||||
)
|
||||
}
|
||||
|
||||
// joinedHostsAtEvent works out a list of matrix servers that were joined to
|
||||
// the room at the event (including peeking ones)
|
||||
// It is important to use the state at the event for sending messages because:
|
||||
// 1) We shouldn't send messages to servers that weren't in the room.
|
||||
// 2) If a server is kicked from the rooms it should still be told about the
|
||||
// kick event,
|
||||
// Usually the list can be calculated locally, but sometimes it will need fetch
|
||||
// events from the room server.
|
||||
// Returns an error if there was a problem talking to the room server.
|
||||
func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
|
||||
ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
// Combine the delta into a single delta so that the adds and removes can
|
||||
// cancel each other out. This should reduce the number of times we need
|
||||
// to fetch a state event from the room server.
|
||||
combinedAdds, combinedRemoves := combineDeltas(
|
||||
ore.AddsStateEventIDs, ore.RemovesStateEventIDs,
|
||||
ore.StateBeforeAddsEventIDs, ore.StateBeforeRemovesEventIDs,
|
||||
)
|
||||
combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ore.Event.Event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
combinedAddsJoinedHosts, err := joinedHostsFromEvents(combinedAddsEvents)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
removed := map[string]bool{}
|
||||
for _, eventID := range combinedRemoves {
|
||||
removed[eventID] = true
|
||||
}
|
||||
|
||||
joined := map[gomatrixserverlib.ServerName]bool{}
|
||||
for _, joinedHost := range oldJoinedHosts {
|
||||
if removed[joinedHost.MemberEventID] {
|
||||
// This m.room.member event is part of the current state of the
|
||||
// room, but not part of the state at the event we are processing
|
||||
// Therefore we can't use it to tell whether the server was in
|
||||
// the room at the event.
|
||||
continue
|
||||
}
|
||||
joined[joinedHost.ServerName] = true
|
||||
}
|
||||
|
||||
for _, joinedHost := range combinedAddsJoinedHosts {
|
||||
// This m.room.member event was part of the state of the room at the
|
||||
// event, but isn't part of the current state of the room now.
|
||||
joined[joinedHost.ServerName] = true
|
||||
}
|
||||
|
||||
// handle peeking hosts
|
||||
inboundPeeks, err := s.db.GetInboundPeeks(context.TODO(), ore.Event.Event.RoomID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, inboundPeek := range inboundPeeks {
|
||||
joined[inboundPeek.ServerName] = true
|
||||
}
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for serverName, include := range joined {
|
||||
if include {
|
||||
result = append(result, serverName)
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// joinedHostsFromEvents turns a list of state events into a list of joined hosts.
|
||||
// This errors if one of the events was invalid.
|
||||
// It should be impossible for an invalid event to get this far in the pipeline.
|
||||
func joinedHostsFromEvents(evs []*gomatrixserverlib.Event) ([]types.JoinedHost, error) {
|
||||
var joinedHosts []types.JoinedHost
|
||||
for _, ev := range evs {
|
||||
if ev.Type() != "m.room.member" || ev.StateKey() == nil {
|
||||
continue
|
||||
}
|
||||
membership, err := ev.Membership()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if membership != gomatrixserverlib.Join {
|
||||
continue
|
||||
}
|
||||
_, serverName, err := gomatrixserverlib.SplitID('@', *ev.StateKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
joinedHosts = append(joinedHosts, types.JoinedHost{
|
||||
MemberEventID: ev.EventID(), ServerName: serverName,
|
||||
})
|
||||
}
|
||||
return joinedHosts, nil
|
||||
}
|
||||
|
||||
// combineDeltas combines two deltas into a single delta.
|
||||
// Assumes that the order of operations is add(1), remove(1), add(2), remove(2).
|
||||
// Removes duplicate entries and redundant operations from each delta.
|
||||
func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []string) {
|
||||
addSet := map[string]bool{}
|
||||
removeSet := map[string]bool{}
|
||||
|
||||
// combine processes each unique value in a list.
|
||||
// If the value is in the removeFrom set then it is removed from that set.
|
||||
// Otherwise it is added to the addTo set.
|
||||
combine := func(values []string, removeFrom, addTo map[string]bool) {
|
||||
processed := map[string]bool{}
|
||||
for _, value := range values {
|
||||
if processed[value] {
|
||||
continue
|
||||
}
|
||||
processed[value] = true
|
||||
if removeFrom[value] {
|
||||
delete(removeFrom, value)
|
||||
} else {
|
||||
addTo[value] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
combine(adds1, nil, addSet)
|
||||
combine(removes1, addSet, removeSet)
|
||||
combine(adds2, removeSet, addSet)
|
||||
combine(removes2, addSet, removeSet)
|
||||
|
||||
for value := range addSet {
|
||||
adds = append(adds, value)
|
||||
}
|
||||
for value := range removeSet {
|
||||
removes = append(removes, value)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// lookupStateEvents looks up the state events that are added by a new event.
|
||||
func (s *OutputRoomEventConsumer) lookupStateEvents(
|
||||
addsStateEventIDs []string, event *gomatrixserverlib.Event,
|
||||
) ([]*gomatrixserverlib.Event, error) {
|
||||
// Fast path if there aren't any new state events.
|
||||
if len(addsStateEventIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Fast path if the only state event added is the event itself.
|
||||
if len(addsStateEventIDs) == 1 && addsStateEventIDs[0] == event.EventID() {
|
||||
return []*gomatrixserverlib.Event{event}, nil
|
||||
}
|
||||
|
||||
missing := addsStateEventIDs
|
||||
var result []*gomatrixserverlib.Event
|
||||
|
||||
// Check if event itself is being added.
|
||||
for _, eventID := range missing {
|
||||
if eventID == event.EventID() {
|
||||
result = append(result, event)
|
||||
break
|
||||
}
|
||||
}
|
||||
missing = missingEventsFrom(result, addsStateEventIDs)
|
||||
|
||||
if len(missing) == 0 {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// At this point the missing events are neither the event itself nor are
|
||||
// they present in our local database. Our only option is to fetch them
|
||||
// from the roomserver using the query API.
|
||||
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
|
||||
var eventResp api.QueryEventsByIDResponse
|
||||
if err := s.rsAPI.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, headeredEvent := range eventResp.Events {
|
||||
result = append(result, headeredEvent.Event)
|
||||
}
|
||||
|
||||
missing = missingEventsFrom(result, addsStateEventIDs)
|
||||
|
||||
if len(missing) != 0 {
|
||||
return nil, fmt.Errorf(
|
||||
"missing %d state events IDs at event %q", len(missing), event.EventID(),
|
||||
)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func missingEventsFrom(events []*gomatrixserverlib.Event, required []string) []string {
|
||||
have := map[string]bool{}
|
||||
for _, event := range events {
|
||||
have[event.EventID()] = true
|
||||
}
|
||||
var missing []string
|
||||
for _, eventID := range required {
|
||||
if !have[eventID] {
|
||||
missing = append(missing, eventID)
|
||||
}
|
||||
}
|
||||
return missing
|
||||
}
|
53
federationapi/consumers/roomserver_test.go
Normal file
53
federationapi/consumers/roomserver_test.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package consumers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCombineNoOp(t *testing.T) {
|
||||
inputAdd1 := []string{"a", "b", "c"}
|
||||
inputDel1 := []string{"a", "b", "d"}
|
||||
inputAdd2 := []string{"a", "d", "e"}
|
||||
inputDel2 := []string{"a", "c", "e", "e"}
|
||||
|
||||
gotAdd, gotDel := combineDeltas(inputAdd1, inputDel1, inputAdd2, inputDel2)
|
||||
|
||||
if len(gotAdd) != 0 {
|
||||
t.Errorf("wanted combined adds to be an empty list, got %#v", gotAdd)
|
||||
}
|
||||
|
||||
if len(gotDel) != 0 {
|
||||
t.Errorf("wanted combined removes to be an empty list, got %#v", gotDel)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombineDedup(t *testing.T) {
|
||||
inputAdd1 := []string{"a", "a"}
|
||||
inputDel1 := []string{"b", "b"}
|
||||
inputAdd2 := []string{"a", "a"}
|
||||
inputDel2 := []string{"b", "b"}
|
||||
|
||||
gotAdd, gotDel := combineDeltas(inputAdd1, inputDel1, inputAdd2, inputDel2)
|
||||
|
||||
if len(gotAdd) != 1 || gotAdd[0] != "a" {
|
||||
t.Errorf("wanted combined adds to be %#v, got %#v", []string{"a"}, gotAdd)
|
||||
}
|
||||
|
||||
if len(gotDel) != 1 || gotDel[0] != "b" {
|
||||
t.Errorf("wanted combined removes to be %#v, got %#v", []string{"b"}, gotDel)
|
||||
}
|
||||
}
|
|
@ -17,17 +17,33 @@ package federationapi
|
|||
import (
|
||||
"github.com/gorilla/mux"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/consumers"
|
||||
"github.com/matrix-org/dendrite/federationapi/internal"
|
||||
"github.com/matrix-org/dendrite/federationapi/inthttp"
|
||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/setup/kafka"
|
||||
userapi "github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/routing"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions
|
||||
// on the given input API.
|
||||
func AddInternalRoutes(router *mux.Router, intAPI api.FederationInternalAPI) {
|
||||
inthttp.AddRoutes(intAPI, router)
|
||||
}
|
||||
|
||||
// AddPublicRoutes sets up and registers HTTP handlers on the base API muxes for the FederationAPI component.
|
||||
func AddPublicRoutes(
|
||||
fedRouter, keyRouter, wellKnownRouter *mux.Router,
|
||||
|
@ -36,7 +52,7 @@ func AddPublicRoutes(
|
|||
federation *gomatrixserverlib.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
federationSenderAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||
federationAPI federationAPI.FederationInternalAPI,
|
||||
eduAPI eduserverAPI.EDUServerInputAPI,
|
||||
keyAPI keyserverAPI.KeyInternalAPI,
|
||||
mscCfg *config.MSCs,
|
||||
|
@ -44,8 +60,70 @@ func AddPublicRoutes(
|
|||
) {
|
||||
routing.Setup(
|
||||
fedRouter, keyRouter, wellKnownRouter, cfg, rsAPI,
|
||||
eduAPI, federationSenderAPI, keyRing,
|
||||
eduAPI, federationAPI, keyRing,
|
||||
federation, userAPI, keyAPI, mscCfg,
|
||||
servers,
|
||||
)
|
||||
}
|
||||
|
||||
// NewInternalAPI returns a concerete implementation of the internal API. Callers
|
||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||
func NewInternalAPI(
|
||||
base *base.BaseDendrite,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
caches *caching.Caches,
|
||||
resetBlacklist bool,
|
||||
) api.FederationInternalAPI {
|
||||
cfg := &base.Cfg.FederationAPI
|
||||
|
||||
federationDB, err := storage.NewDatabase(&cfg.Database, base.Caches)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panic("failed to connect to federation sender db")
|
||||
}
|
||||
|
||||
if resetBlacklist {
|
||||
_ = federationDB.RemoveAllServersFromBlacklist()
|
||||
}
|
||||
|
||||
stats := &statistics.Statistics{
|
||||
DB: federationDB,
|
||||
FailuresUntilBlacklist: cfg.FederationMaxRetries,
|
||||
}
|
||||
|
||||
consumer, _ := kafka.SetupConsumerProducer(&cfg.Matrix.Kafka)
|
||||
|
||||
queues := queue.NewOutgoingQueues(
|
||||
federationDB, base.ProcessContext,
|
||||
cfg.Matrix.DisableFederation,
|
||||
cfg.Matrix.ServerName, federation, rsAPI, stats,
|
||||
&queue.SigningInfo{
|
||||
KeyID: cfg.Matrix.KeyID,
|
||||
PrivateKey: cfg.Matrix.PrivateKey,
|
||||
ServerName: cfg.Matrix.ServerName,
|
||||
},
|
||||
)
|
||||
|
||||
rsConsumer := consumers.NewOutputRoomEventConsumer(
|
||||
base.ProcessContext, cfg, consumer, queues,
|
||||
federationDB, rsAPI,
|
||||
)
|
||||
if err = rsConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start room server consumer")
|
||||
}
|
||||
|
||||
tsConsumer := consumers.NewOutputEDUConsumer(
|
||||
base.ProcessContext, cfg, consumer, queues, federationDB,
|
||||
)
|
||||
if err := tsConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start typing server consumer")
|
||||
}
|
||||
keyConsumer := consumers.NewKeyChangeConsumer(
|
||||
base.ProcessContext, &base.Cfg.KeyServer, consumer, queues, federationDB, rsAPI,
|
||||
)
|
||||
if err := keyConsumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panic("failed to start key server consumer")
|
||||
}
|
||||
|
||||
return internal.NewFederationInternalAPI(federationDB, cfg, rsAPI, federation, stats, caches, queues)
|
||||
}
|
||||
|
|
320
federationapi/federationapi_keys_test.go
Normal file
320
federationapi/federationapi_keys_test.go
Normal file
|
@ -0,0 +1,320 @@
|
|||
package federationapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/routing"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type server struct {
|
||||
name gomatrixserverlib.ServerName // server name
|
||||
validity time.Duration // key validity duration from now
|
||||
config *config.FederationAPI // skeleton config, from TestMain
|
||||
fedclient *gomatrixserverlib.FederationClient // uses MockRoundTripper
|
||||
cache *caching.Caches // server-specific cache
|
||||
api api.FederationInternalAPI // server-specific server key API
|
||||
}
|
||||
|
||||
func (s *server) renew() {
|
||||
// This updates the validity period to be an hour in the
|
||||
// future, which is particularly useful in server A and
|
||||
// server C's cases which have validity either as now or
|
||||
// in the past.
|
||||
s.validity = time.Hour
|
||||
s.config.Matrix.KeyValidityPeriod = s.validity
|
||||
}
|
||||
|
||||
var (
|
||||
serverKeyID = gomatrixserverlib.KeyID("ed25519:auto")
|
||||
serverA = &server{name: "a.com", validity: time.Duration(0)} // expires now
|
||||
serverB = &server{name: "b.com", validity: time.Hour} // expires in an hour
|
||||
serverC = &server{name: "c.com", validity: -time.Hour} // expired an hour ago
|
||||
)
|
||||
|
||||
var servers = map[string]*server{
|
||||
"a.com": serverA,
|
||||
"b.com": serverB,
|
||||
"c.com": serverC,
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
// Set up the server key API for each "server" that we
|
||||
// will use in our tests.
|
||||
for _, s := range servers {
|
||||
// Generate a new key.
|
||||
_, testPriv, err := ed25519.GenerateKey(nil)
|
||||
if err != nil {
|
||||
panic("can't generate identity key: " + err.Error())
|
||||
}
|
||||
|
||||
// Create a new cache but don't enable prometheus!
|
||||
s.cache, err = caching.NewInMemoryLRUCache(false)
|
||||
if err != nil {
|
||||
panic("can't create cache: " + err.Error())
|
||||
}
|
||||
|
||||
// Draw up just enough Dendrite config for the server key
|
||||
// API to work.
|
||||
cfg := &config.Dendrite{}
|
||||
cfg.Defaults()
|
||||
cfg.Global.ServerName = gomatrixserverlib.ServerName(s.name)
|
||||
cfg.Global.PrivateKey = testPriv
|
||||
cfg.Global.Kafka.UseNaffka = true
|
||||
cfg.Global.KeyID = serverKeyID
|
||||
cfg.Global.KeyValidityPeriod = s.validity
|
||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
|
||||
s.config = &cfg.FederationAPI
|
||||
|
||||
// Create a transport which redirects federation requests to
|
||||
// the mock round tripper. Since we're not *really* listening for
|
||||
// federation requests then this will return the key instead.
|
||||
transport := &http.Transport{}
|
||||
transport.RegisterProtocol("matrix", &MockRoundTripper{})
|
||||
|
||||
// Create the federation client.
|
||||
s.fedclient = gomatrixserverlib.NewFederationClient(
|
||||
s.config.Matrix.ServerName, serverKeyID, testPriv,
|
||||
gomatrixserverlib.WithTransport(transport),
|
||||
)
|
||||
|
||||
// Finally, build the server key APIs.
|
||||
sbase := base.NewBaseDendrite(cfg, "Monolith", base.NoCacheMetrics)
|
||||
s.api = NewInternalAPI(sbase, s.fedclient, nil, s.cache, true)
|
||||
}
|
||||
|
||||
// Now that we have built our server key APIs, start the
|
||||
// rest of the tests.
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
type MockRoundTripper struct{}
|
||||
|
||||
func (m *MockRoundTripper) RoundTrip(req *http.Request) (res *http.Response, err error) {
|
||||
// Check if the request is looking for keys from a server that
|
||||
// we know about in the test. The only reason this should go wrong
|
||||
// is if the test is broken.
|
||||
s, ok := servers[req.Host]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("server not known: %s", req.Host)
|
||||
}
|
||||
|
||||
// We're intercepting /matrix/key/v2/server requests here, so check
|
||||
// that the URL supplied in the request is for that.
|
||||
if req.URL.Path != "/_matrix/key/v2/server" {
|
||||
return nil, fmt.Errorf("unexpected request path: %s", req.URL.Path)
|
||||
}
|
||||
|
||||
// Get the keys and JSON-ify them.
|
||||
keys := routing.LocalKeys(s.config)
|
||||
body, err := json.MarshalIndent(keys.JSON, "", " ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// And respond.
|
||||
res = &http.Response{
|
||||
StatusCode: 200,
|
||||
Body: ioutil.NopCloser(bytes.NewReader(body)),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func TestServersRequestOwnKeys(t *testing.T) {
|
||||
// Each server will request its own keys. There's no reason
|
||||
// for this to fail as each server should know its own keys.
|
||||
|
||||
for name, s := range servers {
|
||||
req := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: s.name,
|
||||
KeyID: serverKeyID,
|
||||
}
|
||||
res, err := s.api.FetchKeys(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
|
||||
req: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("server could not fetch own key: %s", err)
|
||||
}
|
||||
if _, ok := res[req]; !ok {
|
||||
t.Fatalf("server didn't return its own key in the results")
|
||||
}
|
||||
t.Logf("%s's key expires at %s\n", name, res[req].ValidUntilTS.Time())
|
||||
}
|
||||
}
|
||||
|
||||
func TestCachingBehaviour(t *testing.T) {
|
||||
// Server A will request Server B's key, which has a validity
|
||||
// period of an hour from now. We should retrieve the key and
|
||||
// it should make it into the cache automatically.
|
||||
|
||||
req := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: serverB.name,
|
||||
KeyID: serverKeyID,
|
||||
}
|
||||
ts := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
|
||||
res, err := serverA.api.FetchKeys(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
|
||||
req: ts,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("server A failed to retrieve server B key: %s", err)
|
||||
}
|
||||
if len(res) != 1 {
|
||||
t.Fatalf("server B should have returned one key but instead returned %d keys", len(res))
|
||||
}
|
||||
if _, ok := res[req]; !ok {
|
||||
t.Fatalf("server B isn't included in the key fetch response")
|
||||
}
|
||||
|
||||
// At this point, if the previous key request was a success,
|
||||
// then the cache should now contain the key. Check if that's
|
||||
// the case - if it isn't then there's something wrong with
|
||||
// the cache implementation or we failed to get the key.
|
||||
|
||||
cres, ok := serverA.cache.GetServerKey(req, ts)
|
||||
if !ok {
|
||||
t.Fatalf("server B key should be in cache but isn't")
|
||||
}
|
||||
if !reflect.DeepEqual(cres, res[req]) {
|
||||
t.Fatalf("the cached result from server B wasn't what server B gave us")
|
||||
}
|
||||
|
||||
// If we ask the cache for the same key but this time for an event
|
||||
// that happened in +30 minutes. Since the validity period is for
|
||||
// another hour, then we should get a response back from the cache.
|
||||
|
||||
_, ok = serverA.cache.GetServerKey(
|
||||
req,
|
||||
gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*30)),
|
||||
)
|
||||
if !ok {
|
||||
t.Fatalf("server B key isn't in cache when it should be (+30 minutes)")
|
||||
}
|
||||
|
||||
// If we ask the cache for the same key but this time for an event
|
||||
// that happened in +90 minutes then we should expect to get no
|
||||
// cache result. This is because the cache shouldn't return a result
|
||||
// that is obviously past the validity of the event.
|
||||
|
||||
_, ok = serverA.cache.GetServerKey(
|
||||
req,
|
||||
gomatrixserverlib.AsTimestamp(time.Now().Add(time.Minute*90)),
|
||||
)
|
||||
if ok {
|
||||
t.Fatalf("server B key is in cache when it shouldn't be (+90 minutes)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRenewalBehaviour(t *testing.T) {
|
||||
// Server A will request Server C's key but their validity period
|
||||
// is an hour in the past. We'll retrieve the key as, even though it's
|
||||
// past its validity, it will be able to verify past events.
|
||||
|
||||
req := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: serverC.name,
|
||||
KeyID: serverKeyID,
|
||||
}
|
||||
|
||||
res, err := serverA.api.FetchKeys(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
|
||||
req: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("server A failed to retrieve server C key: %s", err)
|
||||
}
|
||||
if len(res) != 1 {
|
||||
t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
|
||||
}
|
||||
if _, ok := res[req]; !ok {
|
||||
t.Fatalf("server C isn't included in the key fetch response")
|
||||
}
|
||||
|
||||
// If we ask the cache for the server key for an event that happened
|
||||
// 90 minutes ago then we should get a cache result, as the key hadn't
|
||||
// passed its validity by that point. The fact that the key is now in
|
||||
// the cache is, in itself, proof that we successfully retrieved the
|
||||
// key before.
|
||||
|
||||
oldcached, ok := serverA.cache.GetServerKey(
|
||||
req,
|
||||
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*90)),
|
||||
)
|
||||
if !ok {
|
||||
t.Fatalf("server C key isn't in cache when it should be (-90 minutes)")
|
||||
}
|
||||
|
||||
// If we now ask the cache for the same key but this time for an event
|
||||
// that only happened 30 minutes ago then we shouldn't get a cached
|
||||
// result, as the event happened after the key validity expired. This
|
||||
// is really just for sanity checking.
|
||||
|
||||
_, ok = serverA.cache.GetServerKey(
|
||||
req,
|
||||
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
|
||||
)
|
||||
if ok {
|
||||
t.Fatalf("server B key is in cache when it shouldn't be (-30 minutes)")
|
||||
}
|
||||
|
||||
// We're now going to kick server C into renewing its key. Since we're
|
||||
// happy at this point that the key that we already have is from the past
|
||||
// then repeating a key fetch should cause us to try and renew the key.
|
||||
// If so, then the new key will end up in our cache.
|
||||
|
||||
serverC.renew()
|
||||
|
||||
res, err = serverA.api.FetchKeys(
|
||||
context.Background(),
|
||||
map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{
|
||||
req: gomatrixserverlib.AsTimestamp(time.Now()),
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("server A failed to retrieve server C key: %s", err)
|
||||
}
|
||||
if len(res) != 1 {
|
||||
t.Fatalf("server C should have returned one key but instead returned %d keys", len(res))
|
||||
}
|
||||
if _, ok = res[req]; !ok {
|
||||
t.Fatalf("server C isn't included in the key fetch response")
|
||||
}
|
||||
|
||||
// We're now going to ask the cache what the new key validity is. If
|
||||
// it is still the same as the previous validity then we've failed to
|
||||
// retrieve the renewed key. If it's newer then we've successfully got
|
||||
// the renewed key.
|
||||
|
||||
newcached, ok := serverA.cache.GetServerKey(
|
||||
req,
|
||||
gomatrixserverlib.AsTimestamp(time.Now().Add(-time.Minute*30)),
|
||||
)
|
||||
if !ok {
|
||||
t.Fatalf("server B key isn't in cache when it shouldn't be (post-renewal)")
|
||||
}
|
||||
if oldcached.ValidUntilTS >= newcached.ValidUntilTS {
|
||||
t.Fatalf("the server B key should have been renewed but wasn't")
|
||||
}
|
||||
t.Log(res)
|
||||
}
|
|
@ -8,7 +8,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/federationapi"
|
||||
"github.com/matrix-org/dendrite/internal/test"
|
||||
"github.com/matrix-org/dendrite/setup"
|
||||
"github.com/matrix-org/dendrite/setup/base"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -25,10 +25,10 @@ func TestRoomsV3URLEscapeDoNot404(t *testing.T) {
|
|||
cfg.Global.PrivateKey = privKey
|
||||
cfg.Global.Kafka.UseNaffka = true
|
||||
cfg.Global.Kafka.Database.ConnectionString = config.DataSource("file::memory:")
|
||||
cfg.FederationSender.Database.ConnectionString = config.DataSource("file::memory:")
|
||||
base := setup.NewBaseDendrite(cfg, "Monolith", false)
|
||||
cfg.FederationAPI.Database.ConnectionString = config.DataSource("file::memory:")
|
||||
base := base.NewBaseDendrite(cfg, "Monolith", base.NoCacheMetrics)
|
||||
keyRing := &test.NopJSONVerifier{}
|
||||
fsAPI := base.FederationSenderHTTPClient()
|
||||
fsAPI := base.FederationAPIHTTPClient()
|
||||
// TODO: This is pretty fragile, as if anything calls anything on these nils this test will break.
|
||||
// Unfortunately, it makes little sense to instantiate these dependencies when we just want to test routing.
|
||||
federationapi.AddPublicRoutes(base.PublicFederationAPIMux, base.PublicKeyAPIMux, base.PublicWellKnownAPIMux, &cfg.FederationAPI, nil, nil, keyRing, nil, fsAPI, nil, nil, &cfg.MSCs, nil)
|
||||
|
|
304
federationapi/internal/api.go
Normal file
304
federationapi/internal/api.go
Normal file
|
@ -0,0 +1,304 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/federationapi/queue"
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/cache"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// FederationInternalAPI is an implementation of api.FederationInternalAPI
|
||||
type FederationInternalAPI struct {
|
||||
db storage.Database
|
||||
cfg *config.FederationAPI
|
||||
statistics *statistics.Statistics
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI
|
||||
federation *gomatrixserverlib.FederationClient
|
||||
keyRing *gomatrixserverlib.KeyRing
|
||||
queues *queue.OutgoingQueues
|
||||
joins sync.Map // joins currently in progress
|
||||
}
|
||||
|
||||
func NewFederationInternalAPI(
|
||||
db storage.Database, cfg *config.FederationAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
statistics *statistics.Statistics,
|
||||
caches *caching.Caches,
|
||||
queues *queue.OutgoingQueues,
|
||||
) *FederationInternalAPI {
|
||||
serverKeyDB, err := cache.NewKeyDatabase(db, caches)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to set up caching wrapper for server key database")
|
||||
}
|
||||
|
||||
keyRing := &gomatrixserverlib.KeyRing{
|
||||
KeyFetchers: []gomatrixserverlib.KeyFetcher{},
|
||||
KeyDatabase: serverKeyDB,
|
||||
}
|
||||
|
||||
addDirectFetcher := func() {
|
||||
keyRing.KeyFetchers = append(
|
||||
keyRing.KeyFetchers,
|
||||
&gomatrixserverlib.DirectKeyFetcher{
|
||||
Client: federation,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
if cfg.PreferDirectFetch {
|
||||
addDirectFetcher()
|
||||
} else {
|
||||
defer addDirectFetcher()
|
||||
}
|
||||
|
||||
var b64e = base64.StdEncoding.WithPadding(base64.NoPadding)
|
||||
for _, ps := range cfg.KeyPerspectives {
|
||||
perspective := &gomatrixserverlib.PerspectiveKeyFetcher{
|
||||
PerspectiveServerName: ps.ServerName,
|
||||
PerspectiveServerKeys: map[gomatrixserverlib.KeyID]ed25519.PublicKey{},
|
||||
Client: federation,
|
||||
}
|
||||
|
||||
for _, key := range ps.Keys {
|
||||
rawkey, err := b64e.DecodeString(key.PublicKey)
|
||||
if err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"server_name": ps.ServerName,
|
||||
"public_key": key.PublicKey,
|
||||
}).Warn("Couldn't parse perspective key")
|
||||
continue
|
||||
}
|
||||
perspective.PerspectiveServerKeys[key.KeyID] = rawkey
|
||||
}
|
||||
|
||||
keyRing.KeyFetchers = append(keyRing.KeyFetchers, perspective)
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"server_name": ps.ServerName,
|
||||
"num_public_keys": len(ps.Keys),
|
||||
}).Info("Enabled perspective key fetcher")
|
||||
}
|
||||
|
||||
return &FederationInternalAPI{
|
||||
db: db,
|
||||
cfg: cfg,
|
||||
rsAPI: rsAPI,
|
||||
keyRing: keyRing,
|
||||
federation: federation,
|
||||
statistics: statistics,
|
||||
queues: queues,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) isBlacklistedOrBackingOff(s gomatrixserverlib.ServerName) (*statistics.ServerStatistics, error) {
|
||||
stats := a.statistics.ForServer(s)
|
||||
until, blacklisted := stats.BackoffInfo()
|
||||
if blacklisted {
|
||||
return stats, &api.FederationClientError{
|
||||
Blacklisted: true,
|
||||
}
|
||||
}
|
||||
now := time.Now()
|
||||
if until != nil && now.Before(*until) {
|
||||
return stats, &api.FederationClientError{
|
||||
RetryAfter: time.Until(*until),
|
||||
}
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func failBlacklistableError(err error, stats *statistics.ServerStatistics) (until time.Time, blacklisted bool) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
mxerr, ok := err.(gomatrix.HTTPError)
|
||||
if !ok {
|
||||
return stats.Failure()
|
||||
}
|
||||
if mxerr.Code == 401 { // invalid signature in X-Matrix header
|
||||
return stats.Failure()
|
||||
}
|
||||
if mxerr.Code >= 500 && mxerr.Code < 600 { // internal server errors
|
||||
return stats.Failure()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) doRequest(
|
||||
s gomatrixserverlib.ServerName, request func() (interface{}, error),
|
||||
) (interface{}, error) {
|
||||
stats, err := a.isBlacklistedOrBackingOff(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := request()
|
||||
if err != nil {
|
||||
until, blacklisted := failBlacklistableError(err, stats)
|
||||
now := time.Now()
|
||||
var retryAfter time.Duration
|
||||
if until.After(now) {
|
||||
retryAfter = time.Until(until)
|
||||
}
|
||||
return res, &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
Blacklisted: blacklisted,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
stats.Success()
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) GetUserDevices(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
|
||||
) (gomatrixserverlib.RespUserDevices, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.GetUserDevices(ctx, s, userID)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespUserDevices{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.RespUserDevices), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) ClaimKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
|
||||
) (gomatrixserverlib.RespClaimKeys, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.ClaimKeys(ctx, s, oneTimeKeys)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespClaimKeys{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.RespClaimKeys), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) QueryKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
|
||||
) (gomatrixserverlib.RespQueryKeys, error) {
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.QueryKeys(ctx, s, keys)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespQueryKeys{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.RespQueryKeys), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) Backfill(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
|
||||
) (res gomatrixserverlib.Transaction, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.Backfill(ctx, s, roomID, limit, eventIDs)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.Transaction{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.Transaction), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) LookupState(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (res gomatrixserverlib.RespState, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.LookupState(ctx, s, roomID, eventID, roomVersion)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespState{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.RespState), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) LookupStateIDs(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
|
||||
) (res gomatrixserverlib.RespStateIDs, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.LookupStateIDs(ctx, s, roomID, eventID)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespStateIDs{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.RespStateIDs), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) GetEvent(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
|
||||
) (res gomatrixserverlib.Transaction, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.GetEvent(ctx, s, eventID)
|
||||
})
|
||||
if err != nil {
|
||||
return gomatrixserverlib.Transaction{}, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.Transaction), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) LookupServerKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.LookupServerKeys(ctx, s, keyRequests)
|
||||
})
|
||||
if err != nil {
|
||||
return []gomatrixserverlib.ServerKeys{}, err
|
||||
}
|
||||
return ires.([]gomatrixserverlib.ServerKeys), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) MSC2836EventRelationships(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
|
||||
roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.MSC2836EventRelationships(ctx, s, r, roomVersion)
|
||||
})
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.MSC2836EventRelationshipsResponse), nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Minute)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(s, func() (interface{}, error) {
|
||||
return a.federation.MSC2946Spaces(ctx, s, roomID, r)
|
||||
})
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
return ires.(gomatrixserverlib.MSC2946SpacesResponse), nil
|
||||
}
|
248
federationapi/internal/keys.go
Normal file
248
federationapi/internal/keys.go
Normal file
|
@ -0,0 +1,248 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (s *FederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
|
||||
// Return a keyring that forces requests to be proxied through the
|
||||
// below functions. That way we can enforce things like validity
|
||||
// and keeping the cache up-to-date.
|
||||
return s.keyRing
|
||||
}
|
||||
|
||||
func (s *FederationInternalAPI) StoreKeys(
|
||||
_ context.Context,
|
||||
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
// Run in a background context - we don't want to stop this work just
|
||||
// because the caller gives up waiting.
|
||||
ctx := context.Background()
|
||||
|
||||
// Store any keys that we were given in our database.
|
||||
return s.keyRing.KeyDatabase.StoreKeys(ctx, results)
|
||||
}
|
||||
|
||||
func (s *FederationInternalAPI) FetchKeys(
|
||||
_ context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
// Run in a background context - we don't want to stop this work just
|
||||
// because the caller gives up waiting.
|
||||
ctx := context.Background()
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||
origRequests := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp{}
|
||||
for k, v := range requests {
|
||||
origRequests[k] = v
|
||||
}
|
||||
|
||||
// First, check if any of these key checks are for our own keys. If
|
||||
// they are then we will satisfy them directly.
|
||||
s.handleLocalKeys(ctx, requests, results)
|
||||
|
||||
// Then consult our local database and see if we have the requested
|
||||
// keys. These might come from a cache, depending on the database
|
||||
// implementation used.
|
||||
if err := s.handleDatabaseKeys(ctx, now, requests, results); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// For any key requests that we still have outstanding, next try to
|
||||
// fetch them directly. We'll go through each of the key fetchers to
|
||||
// ask for the remaining keys
|
||||
for _, fetcher := range s.keyRing.KeyFetchers {
|
||||
// If there are no more keys to look up then stop.
|
||||
if len(requests) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
// Ask the fetcher to look up our keys.
|
||||
if err := s.handleFetcherKeys(ctx, now, fetcher, requests, results); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"fetcher_name": fetcher.FetcherName(),
|
||||
}).Errorf("Failed to retrieve %d key(s)", len(requests))
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check that we've actually satisfied all of the key requests that we
|
||||
// were given. We should report an error if we didn't.
|
||||
for req := range origRequests {
|
||||
if _, ok := results[req]; !ok {
|
||||
// The results don't contain anything for this specific request, so
|
||||
// we've failed to satisfy it from local keys, database keys or from
|
||||
// all of the fetchers. Report an error.
|
||||
logrus.Warnf("Failed to retrieve key %q for server %q", req.KeyID, req.ServerName)
|
||||
}
|
||||
}
|
||||
|
||||
// Return the keys.
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *FederationInternalAPI) FetcherName() string {
|
||||
return fmt.Sprintf("FederationInternalAPI (wrapping %q)", s.keyRing.KeyDatabase.FetcherName())
|
||||
}
|
||||
|
||||
// handleLocalKeys handles cases where the key request contains
|
||||
// a request for our own server keys, either current or old.
|
||||
func (s *FederationInternalAPI) handleLocalKeys(
|
||||
_ context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) {
|
||||
for req := range requests {
|
||||
if req.ServerName != s.cfg.Matrix.ServerName {
|
||||
continue
|
||||
}
|
||||
if req.KeyID == s.cfg.Matrix.KeyID {
|
||||
// We found a key request that is supposed to be for our own
|
||||
// keys. Remove it from the request list so we don't hit the
|
||||
// database or the fetchers for it.
|
||||
delete(requests, req)
|
||||
|
||||
// Insert our own key into the response.
|
||||
results[req] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: gomatrixserverlib.VerifyKey{
|
||||
Key: gomatrixserverlib.Base64Bytes(s.cfg.Matrix.PrivateKey.Public().(ed25519.PublicKey)),
|
||||
},
|
||||
ExpiredTS: gomatrixserverlib.PublicKeyNotExpired,
|
||||
ValidUntilTS: gomatrixserverlib.AsTimestamp(time.Now().Add(s.cfg.Matrix.KeyValidityPeriod)),
|
||||
}
|
||||
} else {
|
||||
// The key request doesn't match our current key. Let's see
|
||||
// if it matches any of our old verify keys.
|
||||
for _, oldVerifyKey := range s.cfg.Matrix.OldVerifyKeys {
|
||||
if req.KeyID == oldVerifyKey.KeyID {
|
||||
// We found a key request that is supposed to be an expired
|
||||
// key.
|
||||
delete(requests, req)
|
||||
|
||||
// Insert our own key into the response.
|
||||
results[req] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: gomatrixserverlib.VerifyKey{
|
||||
Key: gomatrixserverlib.Base64Bytes(oldVerifyKey.PrivateKey.Public().(ed25519.PublicKey)),
|
||||
},
|
||||
ExpiredTS: oldVerifyKey.ExpiredAt,
|
||||
ValidUntilTS: gomatrixserverlib.PublicKeyNotValid,
|
||||
}
|
||||
|
||||
// No need to look at the other keys.
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleDatabaseKeys handles cases where the key requests can be
|
||||
// satisfied from our local database/cache.
|
||||
func (s *FederationInternalAPI) handleDatabaseKeys(
|
||||
ctx context.Context,
|
||||
now gomatrixserverlib.Timestamp,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
// Ask the database/cache for the keys.
|
||||
dbResults, err := s.keyRing.KeyDatabase.FetchKeys(ctx, requests)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We successfully got some keys. Add them to the results.
|
||||
for req, res := range dbResults {
|
||||
// The key we've retrieved from the database/cache might
|
||||
// have passed its validity period, but right now, it's
|
||||
// the best thing we've got, and it might be sufficient to
|
||||
// verify a past event.
|
||||
results[req] = res
|
||||
|
||||
// If the key is valid right now then we can also remove it
|
||||
// from the request list as we don't need to fetch it again
|
||||
// in that case. If the key isn't valid right now, then by
|
||||
// leaving it in the 'requests' map, we'll try to update the
|
||||
// key using the fetchers in handleFetcherKeys.
|
||||
if res.WasValidAt(now, true) {
|
||||
delete(requests, req)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleFetcherKeys handles cases where a fetcher can satisfy
|
||||
// the remaining requests.
|
||||
func (s *FederationInternalAPI) handleFetcherKeys(
|
||||
ctx context.Context,
|
||||
_ gomatrixserverlib.Timestamp,
|
||||
fetcher gomatrixserverlib.KeyFetcher,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"fetcher_name": fetcher.FetcherName(),
|
||||
}).Infof("Fetching %d key(s)", len(requests))
|
||||
|
||||
// Create a context that limits our requests to 30 seconds.
|
||||
fetcherCtx, fetcherCancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer fetcherCancel()
|
||||
|
||||
// Try to fetch the keys.
|
||||
fetcherResults, err := fetcher.FetchKeys(fetcherCtx, requests)
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetcher.FetchKeys: %w", err)
|
||||
}
|
||||
|
||||
// Build a map of the results that we want to commit to the
|
||||
// database. We do this in a separate map because otherwise we
|
||||
// might end up trying to rewrite database entries.
|
||||
storeResults := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||
|
||||
// Now let's look at the results that we got from this fetcher.
|
||||
for req, res := range fetcherResults {
|
||||
if prev, ok := results[req]; ok {
|
||||
// We've already got a previous entry for this request
|
||||
// so let's see if the newly retrieved one contains a more
|
||||
// up-to-date validity period.
|
||||
if res.ValidUntilTS > prev.ValidUntilTS {
|
||||
// This key is newer than the one we had so let's store
|
||||
// it in the database.
|
||||
storeResults[req] = res
|
||||
}
|
||||
} else {
|
||||
// We didn't already have a previous entry for this request
|
||||
// so store it in the database anyway for now.
|
||||
storeResults[req] = res
|
||||
}
|
||||
|
||||
// Update the results map with this new result. If nothing
|
||||
// else, we can try verifying against this key.
|
||||
results[req] = res
|
||||
|
||||
// Remove it from the request list so we won't re-fetch it.
|
||||
delete(requests, req)
|
||||
}
|
||||
|
||||
// Store the keys from our store map.
|
||||
if err = s.keyRing.KeyDatabase.StoreKeys(context.Background(), storeResults); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"fetcher_name": fetcher.FetcherName(),
|
||||
"database_name": s.keyRing.KeyDatabase.FetcherName(),
|
||||
}).Errorf("Failed to store keys in the database")
|
||||
return fmt.Errorf("server key API failed to store retrieved keys: %w", err)
|
||||
}
|
||||
|
||||
if len(storeResults) > 0 {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"fetcher_name": fetcher.FetcherName(),
|
||||
}).Infof("Updated %d of %d key(s) in database (%d keys remaining)", len(storeResults), len(results), len(requests))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
727
federationapi/internal/perform.go
Normal file
727
federationapi/internal/perform.go
Normal file
|
@ -0,0 +1,727 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/roomserver/version"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// PerformLeaveRequest implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformDirectoryLookup(
|
||||
ctx context.Context,
|
||||
request *api.PerformDirectoryLookupRequest,
|
||||
response *api.PerformDirectoryLookupResponse,
|
||||
) (err error) {
|
||||
dir, err := r.federation.LookupRoomAlias(
|
||||
ctx,
|
||||
request.ServerName,
|
||||
request.RoomAlias,
|
||||
)
|
||||
if err != nil {
|
||||
r.statistics.ForServer(request.ServerName).Failure()
|
||||
return err
|
||||
}
|
||||
response.RoomID = dir.RoomID
|
||||
response.ServerNames = dir.Servers
|
||||
r.statistics.ForServer(request.ServerName).Success()
|
||||
return nil
|
||||
}
|
||||
|
||||
type federatedJoin struct {
|
||||
UserID string
|
||||
RoomID string
|
||||
}
|
||||
|
||||
// PerformJoin implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformJoin(
|
||||
ctx context.Context,
|
||||
request *api.PerformJoinRequest,
|
||||
response *api.PerformJoinResponse,
|
||||
) {
|
||||
// Check that a join isn't already in progress for this user/room.
|
||||
j := federatedJoin{request.UserID, request.RoomID}
|
||||
if _, found := r.joins.Load(j); found {
|
||||
response.LastError = &gomatrix.HTTPError{
|
||||
Code: 429,
|
||||
Message: `{
|
||||
"errcode": "M_LIMIT_EXCEEDED",
|
||||
"error": "There is already a federated join to this room in progress. Please wait for it to finish."
|
||||
}`, // TODO: Why do none of our error types play nicely with each other?
|
||||
}
|
||||
return
|
||||
}
|
||||
r.joins.Store(j, nil)
|
||||
defer r.joins.Delete(j)
|
||||
|
||||
// Look up the supported room versions.
|
||||
var supportedVersions []gomatrixserverlib.RoomVersion
|
||||
for version := range version.SupportedRoomVersions() {
|
||||
supportedVersions = append(supportedVersions, version)
|
||||
}
|
||||
|
||||
// Deduplicate the server names we were provided but keep the ordering
|
||||
// as this encodes useful information about which servers are most likely
|
||||
// to respond.
|
||||
seenSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||
var uniqueList []gomatrixserverlib.ServerName
|
||||
for _, srv := range request.ServerNames {
|
||||
if seenSet[srv] {
|
||||
continue
|
||||
}
|
||||
seenSet[srv] = true
|
||||
uniqueList = append(uniqueList, srv)
|
||||
}
|
||||
request.ServerNames = uniqueList
|
||||
|
||||
// Try each server that we were provided until we land on one that
|
||||
// successfully completes the make-join send-join dance.
|
||||
var lastErr error
|
||||
for _, serverName := range request.ServerNames {
|
||||
if err := r.performJoinUsingServer(
|
||||
ctx,
|
||||
request.RoomID,
|
||||
request.UserID,
|
||||
request.Content,
|
||||
serverName,
|
||||
supportedVersions,
|
||||
); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"room_id": request.RoomID,
|
||||
}).Warnf("Failed to join room through server")
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// We're all good.
|
||||
response.JoinedVia = serverName
|
||||
return
|
||||
}
|
||||
|
||||
// If we reach here then we didn't complete a join for some reason.
|
||||
var httpErr gomatrix.HTTPError
|
||||
if ok := errors.As(lastErr, &httpErr); ok {
|
||||
httpErr.Message = string(httpErr.Contents)
|
||||
// Clear the wrapped error, else serialising to JSON (in polylith mode) will fail
|
||||
httpErr.WrappedError = nil
|
||||
response.LastError = &httpErr
|
||||
} else {
|
||||
response.LastError = &gomatrix.HTTPError{
|
||||
Code: 0,
|
||||
WrappedError: nil,
|
||||
Message: "Unknown HTTP error",
|
||||
}
|
||||
if lastErr != nil {
|
||||
response.LastError.Message = lastErr.Error()
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Errorf(
|
||||
"failed to join user %q to room %q through %d server(s): last error %s",
|
||||
request.UserID, request.RoomID, len(request.ServerNames), lastErr,
|
||||
)
|
||||
}
|
||||
|
||||
func (r *FederationInternalAPI) performJoinUsingServer(
|
||||
ctx context.Context,
|
||||
roomID, userID string,
|
||||
content map[string]interface{},
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
supportedVersions []gomatrixserverlib.RoomVersion,
|
||||
) error {
|
||||
// Try to perform a make_join using the information supplied in the
|
||||
// request.
|
||||
respMakeJoin, err := r.federation.MakeJoin(
|
||||
ctx,
|
||||
serverName,
|
||||
roomID,
|
||||
userID,
|
||||
supportedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
// TODO: Check if the user was not allowed to join the room.
|
||||
r.statistics.ForServer(serverName).Failure()
|
||||
return fmt.Errorf("r.federation.MakeJoin: %w", err)
|
||||
}
|
||||
r.statistics.ForServer(serverName).Success()
|
||||
|
||||
// Set all the fields to be what they should be, this should be a no-op
|
||||
// but it's possible that the remote server returned us something "odd"
|
||||
respMakeJoin.JoinEvent.Type = gomatrixserverlib.MRoomMember
|
||||
respMakeJoin.JoinEvent.Sender = userID
|
||||
respMakeJoin.JoinEvent.StateKey = &userID
|
||||
respMakeJoin.JoinEvent.RoomID = roomID
|
||||
respMakeJoin.JoinEvent.Redacts = ""
|
||||
if content == nil {
|
||||
content = map[string]interface{}{}
|
||||
}
|
||||
content["membership"] = "join"
|
||||
if err = respMakeJoin.JoinEvent.SetContent(content); err != nil {
|
||||
return fmt.Errorf("respMakeJoin.JoinEvent.SetContent: %w", err)
|
||||
}
|
||||
if err = respMakeJoin.JoinEvent.SetUnsigned(struct{}{}); err != nil {
|
||||
return fmt.Errorf("respMakeJoin.JoinEvent.SetUnsigned: %w", err)
|
||||
}
|
||||
|
||||
// Work out if we support the room version that has been supplied in
|
||||
// the make_join response.
|
||||
// "If not provided, the room version is assumed to be either "1" or "2"."
|
||||
// https://matrix.org/docs/spec/server_server/unstable#get-matrix-federation-v1-make-join-roomid-userid
|
||||
if respMakeJoin.RoomVersion == "" {
|
||||
respMakeJoin.RoomVersion = setDefaultRoomVersionFromJoinEvent(respMakeJoin.JoinEvent)
|
||||
}
|
||||
if _, err = respMakeJoin.RoomVersion.EventFormat(); err != nil {
|
||||
return fmt.Errorf("respMakeJoin.RoomVersion.EventFormat: %w", err)
|
||||
}
|
||||
|
||||
// Build the join event.
|
||||
event, err := respMakeJoin.JoinEvent.Build(
|
||||
time.Now(),
|
||||
r.cfg.Matrix.ServerName,
|
||||
r.cfg.Matrix.KeyID,
|
||||
r.cfg.Matrix.PrivateKey,
|
||||
respMakeJoin.RoomVersion,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("respMakeJoin.JoinEvent.Build: %w", err)
|
||||
}
|
||||
|
||||
// No longer reuse the request context from this point forward.
|
||||
// We don't want the client timing out to interrupt the join.
|
||||
var cancel context.CancelFunc
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
|
||||
// Try to perform a send_join using the newly built event.
|
||||
respSendJoin, err := r.federation.SendJoin(
|
||||
ctx,
|
||||
serverName,
|
||||
event,
|
||||
respMakeJoin.RoomVersion,
|
||||
)
|
||||
if err != nil {
|
||||
r.statistics.ForServer(serverName).Failure()
|
||||
cancel()
|
||||
return fmt.Errorf("r.federation.SendJoin: %w", err)
|
||||
}
|
||||
r.statistics.ForServer(serverName).Success()
|
||||
|
||||
// Sanity-check the join response to ensure that it has a create
|
||||
// event, that the room version is known, etc.
|
||||
if err := sanityCheckAuthChain(respSendJoin.AuthEvents); err != nil {
|
||||
cancel()
|
||||
return fmt.Errorf("sanityCheckAuthChain: %w", err)
|
||||
}
|
||||
|
||||
// Process the join response in a goroutine. The idea here is
|
||||
// that we'll try and wait for as long as possible for the work
|
||||
// to complete, but if the client does give up waiting, we'll
|
||||
// still continue to process the join anyway so that we don't
|
||||
// waste the effort.
|
||||
go func() {
|
||||
defer cancel()
|
||||
|
||||
// TODO: Can we expand Check here to return a list of missing auth
|
||||
// events rather than failing one at a time?
|
||||
respState, err := respSendJoin.Check(ctx, r.keyRing, event, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName))
|
||||
if err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"room_id": roomID,
|
||||
"user_id": userID,
|
||||
}).WithError(err).Error("Failed to process room join response")
|
||||
return
|
||||
}
|
||||
|
||||
// If we successfully performed a send_join above then the other
|
||||
// server now thinks we're a part of the room. Send the newly
|
||||
// returned state to the roomserver to update our local view.
|
||||
if err = roomserverAPI.SendEventWithState(
|
||||
ctx, r.rsAPI,
|
||||
roomserverAPI.KindNew,
|
||||
respState,
|
||||
event.Headered(respMakeJoin.RoomVersion),
|
||||
nil,
|
||||
); err != nil {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"room_id": roomID,
|
||||
"user_id": userID,
|
||||
}).WithError(err).Error("Failed to send room join response to roomserver")
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformOutboundPeekRequest implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformOutboundPeek(
|
||||
ctx context.Context,
|
||||
request *api.PerformOutboundPeekRequest,
|
||||
response *api.PerformOutboundPeekResponse,
|
||||
) error {
|
||||
// Look up the supported room versions.
|
||||
var supportedVersions []gomatrixserverlib.RoomVersion
|
||||
for version := range version.SupportedRoomVersions() {
|
||||
supportedVersions = append(supportedVersions, version)
|
||||
}
|
||||
|
||||
// Deduplicate the server names we were provided but keep the ordering
|
||||
// as this encodes useful information about which servers are most likely
|
||||
// to respond.
|
||||
seenSet := make(map[gomatrixserverlib.ServerName]bool)
|
||||
var uniqueList []gomatrixserverlib.ServerName
|
||||
for _, srv := range request.ServerNames {
|
||||
if seenSet[srv] {
|
||||
continue
|
||||
}
|
||||
seenSet[srv] = true
|
||||
uniqueList = append(uniqueList, srv)
|
||||
}
|
||||
request.ServerNames = uniqueList
|
||||
|
||||
// See if there's an existing outbound peek for this room ID with
|
||||
// one of the specified servers.
|
||||
if peeks, err := r.db.GetOutboundPeeks(ctx, request.RoomID); err == nil {
|
||||
for _, peek := range peeks {
|
||||
if _, ok := seenSet[peek.ServerName]; ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Try each server that we were provided until we land on one that
|
||||
// successfully completes the peek
|
||||
var lastErr error
|
||||
for _, serverName := range request.ServerNames {
|
||||
if err := r.performOutboundPeekUsingServer(
|
||||
ctx,
|
||||
request.RoomID,
|
||||
serverName,
|
||||
supportedVersions,
|
||||
); err != nil {
|
||||
logrus.WithError(err).WithFields(logrus.Fields{
|
||||
"server_name": serverName,
|
||||
"room_id": request.RoomID,
|
||||
}).Warnf("Failed to peek room through server")
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// We're all good.
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we reach here then we didn't complete a peek for some reason.
|
||||
var httpErr gomatrix.HTTPError
|
||||
if ok := errors.As(lastErr, &httpErr); ok {
|
||||
httpErr.Message = string(httpErr.Contents)
|
||||
// Clear the wrapped error, else serialising to JSON (in polylith mode) will fail
|
||||
httpErr.WrappedError = nil
|
||||
response.LastError = &httpErr
|
||||
} else {
|
||||
response.LastError = &gomatrix.HTTPError{
|
||||
Code: 0,
|
||||
WrappedError: nil,
|
||||
Message: lastErr.Error(),
|
||||
}
|
||||
}
|
||||
|
||||
logrus.Errorf(
|
||||
"failed to peek room %q through %d server(s): last error %s",
|
||||
request.RoomID, len(request.ServerNames), lastErr,
|
||||
)
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
func (r *FederationInternalAPI) performOutboundPeekUsingServer(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
supportedVersions []gomatrixserverlib.RoomVersion,
|
||||
) error {
|
||||
// create a unique ID for this peek.
|
||||
// for now we just use the room ID again. In future, if we ever
|
||||
// support concurrent peeks to the same room with different filters
|
||||
// then we would need to disambiguate further.
|
||||
peekID := roomID
|
||||
|
||||
// check whether we're peeking already to try to avoid needlessly
|
||||
// re-peeking on the server. we don't need a transaction for this,
|
||||
// given this is a nice-to-have.
|
||||
outboundPeek, err := r.db.GetOutboundPeek(ctx, serverName, roomID, peekID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
renewing := false
|
||||
if outboundPeek != nil {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
if nowMilli > outboundPeek.RenewedTimestamp+outboundPeek.RenewalInterval {
|
||||
logrus.Infof("stale outbound peek to %s for %s already exists; renewing", serverName, roomID)
|
||||
renewing = true
|
||||
} else {
|
||||
logrus.Infof("live outbound peek to %s for %s already exists", serverName, roomID)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Try to perform an outbound /peek using the information supplied in the
|
||||
// request.
|
||||
respPeek, err := r.federation.Peek(
|
||||
ctx,
|
||||
serverName,
|
||||
roomID,
|
||||
peekID,
|
||||
supportedVersions,
|
||||
)
|
||||
if err != nil {
|
||||
r.statistics.ForServer(serverName).Failure()
|
||||
return fmt.Errorf("r.federation.Peek: %w", err)
|
||||
}
|
||||
r.statistics.ForServer(serverName).Success()
|
||||
|
||||
// Work out if we support the room version that has been supplied in
|
||||
// the peek response.
|
||||
if respPeek.RoomVersion == "" {
|
||||
respPeek.RoomVersion = gomatrixserverlib.RoomVersionV1
|
||||
}
|
||||
if _, err = respPeek.RoomVersion.EventFormat(); err != nil {
|
||||
return fmt.Errorf("respPeek.RoomVersion.EventFormat: %w", err)
|
||||
}
|
||||
|
||||
// we have the peek state now so let's process regardless of whether upstream gives up
|
||||
ctx = context.Background()
|
||||
|
||||
respState := respPeek.ToRespState()
|
||||
// authenticate the state returned (check its auth events etc)
|
||||
// the equivalent of CheckSendJoinResponse()
|
||||
if err = sanityCheckAuthChain(respState.AuthEvents); err != nil {
|
||||
return fmt.Errorf("sanityCheckAuthChain: %w", err)
|
||||
}
|
||||
if err = respState.Check(ctx, r.keyRing, federatedAuthProvider(ctx, r.federation, r.keyRing, serverName)); err != nil {
|
||||
return fmt.Errorf("error checking state returned from peeking: %w", err)
|
||||
}
|
||||
|
||||
// If we've got this far, the remote server is peeking.
|
||||
if renewing {
|
||||
if err = r.db.RenewOutboundPeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err = r.db.AddOutboundPeek(ctx, serverName, roomID, peekID, respPeek.RenewalInterval); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// logrus.Warnf("got respPeek %#v", respPeek)
|
||||
// Send the newly returned state to the roomserver to update our local view.
|
||||
if err = roomserverAPI.SendEventWithState(
|
||||
ctx, r.rsAPI,
|
||||
roomserverAPI.KindNew,
|
||||
&respState,
|
||||
respPeek.LatestEvent.Headered(respPeek.RoomVersion),
|
||||
nil,
|
||||
); err != nil {
|
||||
return fmt.Errorf("r.producer.SendEventWithState: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformLeaveRequest implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformLeave(
|
||||
ctx context.Context,
|
||||
request *api.PerformLeaveRequest,
|
||||
response *api.PerformLeaveResponse,
|
||||
) (err error) {
|
||||
// Deduplicate the server names we were provided.
|
||||
util.SortAndUnique(request.ServerNames)
|
||||
|
||||
// Try each server that we were provided until we land on one that
|
||||
// successfully completes the make-leave send-leave dance.
|
||||
for _, serverName := range request.ServerNames {
|
||||
// Try to perform a make_leave using the information supplied in the
|
||||
// request.
|
||||
respMakeLeave, err := r.federation.MakeLeave(
|
||||
ctx,
|
||||
serverName,
|
||||
request.RoomID,
|
||||
request.UserID,
|
||||
)
|
||||
if err != nil {
|
||||
// TODO: Check if the user was not allowed to leave the room.
|
||||
logrus.WithError(err).Warnf("r.federation.MakeLeave failed")
|
||||
r.statistics.ForServer(serverName).Failure()
|
||||
continue
|
||||
}
|
||||
|
||||
// Set all the fields to be what they should be, this should be a no-op
|
||||
// but it's possible that the remote server returned us something "odd"
|
||||
respMakeLeave.LeaveEvent.Type = gomatrixserverlib.MRoomMember
|
||||
respMakeLeave.LeaveEvent.Sender = request.UserID
|
||||
respMakeLeave.LeaveEvent.StateKey = &request.UserID
|
||||
respMakeLeave.LeaveEvent.RoomID = request.RoomID
|
||||
respMakeLeave.LeaveEvent.Redacts = ""
|
||||
if respMakeLeave.LeaveEvent.Content == nil {
|
||||
content := map[string]interface{}{
|
||||
"membership": "leave",
|
||||
}
|
||||
if err = respMakeLeave.LeaveEvent.SetContent(content); err != nil {
|
||||
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetContent failed")
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err = respMakeLeave.LeaveEvent.SetUnsigned(struct{}{}); err != nil {
|
||||
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.SetUnsigned failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Work out if we support the room version that has been supplied in
|
||||
// the make_leave response.
|
||||
if _, err = respMakeLeave.RoomVersion.EventFormat(); err != nil {
|
||||
return gomatrixserverlib.UnsupportedRoomVersionError{}
|
||||
}
|
||||
|
||||
// Build the leave event.
|
||||
event, err := respMakeLeave.LeaveEvent.Build(
|
||||
time.Now(),
|
||||
r.cfg.Matrix.ServerName,
|
||||
r.cfg.Matrix.KeyID,
|
||||
r.cfg.Matrix.PrivateKey,
|
||||
respMakeLeave.RoomVersion,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warnf("respMakeLeave.LeaveEvent.Build failed")
|
||||
continue
|
||||
}
|
||||
|
||||
// Try to perform a send_leave using the newly built event.
|
||||
err = r.federation.SendLeave(
|
||||
ctx,
|
||||
serverName,
|
||||
event,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Warnf("r.federation.SendLeave failed")
|
||||
r.statistics.ForServer(serverName).Failure()
|
||||
continue
|
||||
}
|
||||
|
||||
r.statistics.ForServer(serverName).Success()
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we reach here then we didn't complete a leave for some reason.
|
||||
return fmt.Errorf(
|
||||
"failed to leave room %q through %d server(s)",
|
||||
request.RoomID, len(request.ServerNames),
|
||||
)
|
||||
}
|
||||
|
||||
// PerformLeaveRequest implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformInvite(
|
||||
ctx context.Context,
|
||||
request *api.PerformInviteRequest,
|
||||
response *api.PerformInviteResponse,
|
||||
) (err error) {
|
||||
if request.Event.StateKey() == nil {
|
||||
return errors.New("invite must be a state event")
|
||||
}
|
||||
|
||||
_, destination, err := gomatrixserverlib.SplitID('@', *request.Event.StateKey())
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"event_id": request.Event.EventID(),
|
||||
"user_id": *request.Event.StateKey(),
|
||||
"room_id": request.Event.RoomID(),
|
||||
"room_version": request.RoomVersion,
|
||||
"destination": destination,
|
||||
}).Info("Sending invite")
|
||||
|
||||
inviteReq, err := gomatrixserverlib.NewInviteV2Request(request.Event, request.InviteRoomState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.NewInviteV2Request: %w", err)
|
||||
}
|
||||
|
||||
inviteRes, err := r.federation.SendInviteV2(ctx, destination, inviteReq)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.federation.SendInviteV2: %w", err)
|
||||
}
|
||||
|
||||
response.Event = inviteRes.Event.Headered(request.RoomVersion)
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformServersAlive implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformServersAlive(
|
||||
ctx context.Context,
|
||||
request *api.PerformServersAliveRequest,
|
||||
response *api.PerformServersAliveResponse,
|
||||
) (err error) {
|
||||
for _, srv := range request.Servers {
|
||||
_ = r.db.RemoveServerFromBlacklist(srv)
|
||||
r.queues.RetryServer(srv)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// PerformServersAlive implements api.FederationInternalAPI
|
||||
func (r *FederationInternalAPI) PerformBroadcastEDU(
|
||||
ctx context.Context,
|
||||
request *api.PerformBroadcastEDURequest,
|
||||
response *api.PerformBroadcastEDUResponse,
|
||||
) (err error) {
|
||||
destinations, err := r.db.GetAllJoinedHosts(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err)
|
||||
}
|
||||
if len(destinations) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
logrus.WithContext(ctx).Infof("Sending wake-up EDU to %d destination(s)", len(destinations))
|
||||
|
||||
edu := &gomatrixserverlib.EDU{
|
||||
Type: "org.matrix.dendrite.wakeup",
|
||||
Origin: string(r.cfg.Matrix.ServerName),
|
||||
}
|
||||
if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil {
|
||||
return fmt.Errorf("r.queues.SendEDU: %w", err)
|
||||
}
|
||||
|
||||
wakeReq := &api.PerformServersAliveRequest{
|
||||
Servers: destinations,
|
||||
}
|
||||
wakeRes := &api.PerformServersAliveResponse{}
|
||||
if err := r.PerformServersAlive(ctx, wakeReq, wakeRes); err != nil {
|
||||
return fmt.Errorf("r.PerformServersAlive: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func sanityCheckAuthChain(authChain []*gomatrixserverlib.Event) error {
|
||||
// sanity check we have a create event and it has a known room version
|
||||
for _, ev := range authChain {
|
||||
if ev.Type() == gomatrixserverlib.MRoomCreate && ev.StateKeyEquals("") {
|
||||
// make sure the room version is known
|
||||
content := ev.Content()
|
||||
verBody := struct {
|
||||
Version string `json:"room_version"`
|
||||
}{}
|
||||
err := json.Unmarshal(content, &verBody)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if verBody.Version == "" {
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.0#m-room-create
|
||||
// The version of the room. Defaults to "1" if the key does not exist.
|
||||
verBody.Version = "1"
|
||||
}
|
||||
knownVersions := gomatrixserverlib.RoomVersions()
|
||||
if _, ok := knownVersions[gomatrixserverlib.RoomVersion(verBody.Version)]; !ok {
|
||||
return fmt.Errorf("auth chain m.room.create event has an unknown room version: %s", verBody.Version)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("auth chain response is missing m.room.create event")
|
||||
}
|
||||
|
||||
func setDefaultRoomVersionFromJoinEvent(joinEvent gomatrixserverlib.EventBuilder) gomatrixserverlib.RoomVersion {
|
||||
// if auth events are not event references we know it must be v3+
|
||||
// we have to do these shenanigans to satisfy sytest, specifically for:
|
||||
// "Outbound federation rejects m.room.create events with an unknown room version"
|
||||
hasEventRefs := true
|
||||
authEvents, ok := joinEvent.AuthEvents.([]interface{})
|
||||
if ok {
|
||||
if len(authEvents) > 0 {
|
||||
_, ok = authEvents[0].(string)
|
||||
if ok {
|
||||
// event refs are objects, not strings, so we know we must be dealing with a v3+ room.
|
||||
hasEventRefs = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hasEventRefs {
|
||||
return gomatrixserverlib.RoomVersionV1
|
||||
}
|
||||
return gomatrixserverlib.RoomVersionV4
|
||||
}
|
||||
|
||||
// FederatedAuthProvider is an auth chain provider which fetches events from the server provided
|
||||
func federatedAuthProvider(
|
||||
ctx context.Context, federation *gomatrixserverlib.FederationClient,
|
||||
keyRing gomatrixserverlib.JSONVerifier, server gomatrixserverlib.ServerName,
|
||||
) gomatrixserverlib.AuthChainProvider {
|
||||
// A list of events that we have retried, if they were not included in
|
||||
// the auth events supplied in the send_join.
|
||||
retries := map[string][]*gomatrixserverlib.Event{}
|
||||
|
||||
// Define a function which we can pass to Check to retrieve missing
|
||||
// auth events inline. This greatly increases our chances of not having
|
||||
// to repeat the entire set of checks just for a missing event or two.
|
||||
return func(roomVersion gomatrixserverlib.RoomVersion, eventIDs []string) ([]*gomatrixserverlib.Event, error) {
|
||||
returning := []*gomatrixserverlib.Event{}
|
||||
|
||||
// See if we have retry entries for each of the supplied event IDs.
|
||||
for _, eventID := range eventIDs {
|
||||
// If we've already satisfied a request for this event ID before then
|
||||
// just append the results. We won't retry the request.
|
||||
if retry, ok := retries[eventID]; ok {
|
||||
if retry == nil {
|
||||
return nil, fmt.Errorf("missingAuth: not retrying failed event ID %q", eventID)
|
||||
}
|
||||
returning = append(returning, retry...)
|
||||
continue
|
||||
}
|
||||
|
||||
// Make a note of the fact that we tried to do something with this
|
||||
// event ID, even if we don't succeed.
|
||||
retries[eventID] = nil
|
||||
|
||||
// Try to retrieve the event from the server that sent us the send
|
||||
// join response.
|
||||
tx, txerr := federation.GetEvent(ctx, server, eventID)
|
||||
if txerr != nil {
|
||||
return nil, fmt.Errorf("missingAuth r.federation.GetEvent: %w", txerr)
|
||||
}
|
||||
|
||||
// For each event returned, add it to the set of return events. We
|
||||
// also will populate the retries, in case someone asks for this
|
||||
// event ID again.
|
||||
for _, pdu := range tx.PDUs {
|
||||
// Try to parse the event.
|
||||
ev, everr := gomatrixserverlib.NewEventFromUntrustedJSON(pdu, roomVersion)
|
||||
if everr != nil {
|
||||
return nil, fmt.Errorf("missingAuth gomatrixserverlib.NewEventFromUntrustedJSON: %w", everr)
|
||||
}
|
||||
|
||||
// Check the signatures of the event.
|
||||
if err := ev.VerifyEventSignatures(ctx, keyRing); err != nil {
|
||||
return nil, fmt.Errorf("missingAuth VerifyEventSignatures: %w", err)
|
||||
}
|
||||
|
||||
// If the event is OK then add it to the results and the retry map.
|
||||
returning = append(returning, ev)
|
||||
retries[ev.EventID()] = append(retries[ev.EventID()], ev)
|
||||
}
|
||||
}
|
||||
return returning, nil
|
||||
}
|
||||
}
|
97
federationapi/internal/query.go
Normal file
97
federationapi/internal/query.go
Normal file
|
@ -0,0 +1,97 @@
|
|||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// QueryJoinedHostServerNamesInRoom implements api.FederationInternalAPI
|
||||
func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
|
||||
ctx context.Context,
|
||||
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
||||
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
||||
) (err error) {
|
||||
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
response.ServerNames = joinedHosts
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) {
|
||||
ctx, cancel := context.WithTimeout(ctx, time.Second*30)
|
||||
defer cancel()
|
||||
ires, err := a.doRequest(serverName, func() (interface{}, error) {
|
||||
return a.federation.GetServerKeys(ctx, serverName)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sks := ires.(gomatrixserverlib.ServerKeys)
|
||||
return &sks, nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) fetchServerKeysFromCache(
|
||||
ctx context.Context, req *api.QueryServerKeysRequest,
|
||||
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
var results []gomatrixserverlib.ServerKeys
|
||||
for keyID, criteria := range req.KeyIDToCriteria {
|
||||
serverKeysResponses, _ := a.db.GetNotaryKeys(ctx, req.ServerName, []gomatrixserverlib.KeyID{keyID})
|
||||
if len(serverKeysResponses) == 0 {
|
||||
return nil, fmt.Errorf("failed to find server key response for key ID %s", keyID)
|
||||
}
|
||||
// we should only get 1 result as we only gave 1 key ID
|
||||
sk := serverKeysResponses[0]
|
||||
util.GetLogger(ctx).Infof("fetchServerKeysFromCache: minvalid:%v keys: %+v", criteria.MinimumValidUntilTS, sk)
|
||||
if criteria.MinimumValidUntilTS != 0 {
|
||||
// check if it's still valid. if they have the same value that's also valid
|
||||
if sk.ValidUntilTS < criteria.MinimumValidUntilTS {
|
||||
return nil, fmt.Errorf(
|
||||
"found server response for key ID %s but it is no longer valid, min: %v valid_until: %v",
|
||||
keyID, criteria.MinimumValidUntilTS, sk.ValidUntilTS,
|
||||
)
|
||||
}
|
||||
}
|
||||
results = append(results, sk)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (a *FederationInternalAPI) QueryServerKeys(
|
||||
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
|
||||
) error {
|
||||
// attempt to satisfy the entire request from the cache first
|
||||
results, err := a.fetchServerKeysFromCache(ctx, req)
|
||||
if err == nil {
|
||||
// satisfied entirely from cache, return it
|
||||
res.ServerKeys = results
|
||||
return nil
|
||||
}
|
||||
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to satisfy keys request entirely from cache, hitting direct")
|
||||
|
||||
serverKeys, err := a.fetchServerKeysDirectly(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
// try to load as much as we can from the cache in a best effort basis
|
||||
util.GetLogger(ctx).WithField("server", req.ServerName).WithError(err).Warn("notary: failed to ask server for keys, returning best effort keys")
|
||||
serverKeysResponses, dbErr := a.db.GetNotaryKeys(ctx, req.ServerName, req.KeyIDs())
|
||||
if dbErr != nil {
|
||||
return fmt.Errorf("notary: server returned %s, and db returned %s", err, dbErr)
|
||||
}
|
||||
res.ServerKeys = serverKeysResponses
|
||||
return nil
|
||||
}
|
||||
// cache it!
|
||||
if err = a.db.UpdateNotaryKeys(context.Background(), req.ServerName, *serverKeys); err != nil {
|
||||
// non-fatal, still return the response
|
||||
util.GetLogger(ctx).WithError(err).Warn("failed to UpdateNotaryKeys")
|
||||
}
|
||||
res.ServerKeys = []gomatrixserverlib.ServerKeys{*serverKeys}
|
||||
return nil
|
||||
}
|
575
federationapi/inthttp/client.go
Normal file
575
federationapi/inthttp/client.go
Normal file
|
@ -0,0 +1,575 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// HTTP paths for the internal HTTP API
|
||||
const (
|
||||
FederationAPIQueryJoinedHostServerNamesInRoomPath = "/federationapi/queryJoinedHostServerNamesInRoom"
|
||||
FederationAPIQueryServerKeysPath = "/federationapi/queryServerKeys"
|
||||
|
||||
FederationAPIPerformDirectoryLookupRequestPath = "/federationapi/performDirectoryLookup"
|
||||
FederationAPIPerformJoinRequestPath = "/federationapi/performJoinRequest"
|
||||
FederationAPIPerformLeaveRequestPath = "/federationapi/performLeaveRequest"
|
||||
FederationAPIPerformInviteRequestPath = "/federationapi/performInviteRequest"
|
||||
FederationAPIPerformOutboundPeekRequestPath = "/federationapi/performOutboundPeekRequest"
|
||||
FederationAPIPerformServersAlivePath = "/federationapi/performServersAlive"
|
||||
FederationAPIPerformBroadcastEDUPath = "/federationapi/performBroadcastEDU"
|
||||
|
||||
FederationAPIGetUserDevicesPath = "/federationapi/client/getUserDevices"
|
||||
FederationAPIClaimKeysPath = "/federationapi/client/claimKeys"
|
||||
FederationAPIQueryKeysPath = "/federationapi/client/queryKeys"
|
||||
FederationAPIBackfillPath = "/federationapi/client/backfill"
|
||||
FederationAPILookupStatePath = "/federationapi/client/lookupState"
|
||||
FederationAPILookupStateIDsPath = "/federationapi/client/lookupStateIDs"
|
||||
FederationAPIGetEventPath = "/federationapi/client/getEvent"
|
||||
FederationAPILookupServerKeysPath = "/federationapi/client/lookupServerKeys"
|
||||
FederationAPIEventRelationshipsPath = "/federationapi/client/msc2836eventRelationships"
|
||||
FederationAPISpacesSummaryPath = "/federationapi/client/msc2946spacesSummary"
|
||||
|
||||
FederationAPIInputPublicKeyPath = "/federationapi/inputPublicKey"
|
||||
FederationAPIQueryPublicKeyPath = "/federationapi/queryPublicKey"
|
||||
)
|
||||
|
||||
// NewFederationAPIClient creates a FederationInternalAPI implemented by talking to a HTTP POST API.
|
||||
// If httpClient is nil an error is returned
|
||||
func NewFederationAPIClient(federationSenderURL string, httpClient *http.Client, cache caching.ServerKeyCache) (api.FederationInternalAPI, error) {
|
||||
if httpClient == nil {
|
||||
return nil, errors.New("NewFederationInternalAPIHTTP: httpClient is <nil>")
|
||||
}
|
||||
return &httpFederationInternalAPI{federationSenderURL, httpClient, cache}, nil
|
||||
}
|
||||
|
||||
type httpFederationInternalAPI struct {
|
||||
federationAPIURL string
|
||||
httpClient *http.Client
|
||||
cache caching.ServerKeyCache
|
||||
}
|
||||
|
||||
// Handle an instruction to make_leave & send_leave with a remote server.
|
||||
func (h *httpFederationInternalAPI) PerformLeave(
|
||||
ctx context.Context,
|
||||
request *api.PerformLeaveRequest,
|
||||
response *api.PerformLeaveResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLeaveRequest")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformLeaveRequestPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// Handle sending an invite to a remote server.
|
||||
func (h *httpFederationInternalAPI) PerformInvite(
|
||||
ctx context.Context,
|
||||
request *api.PerformInviteRequest,
|
||||
response *api.PerformInviteResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformInviteRequest")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformInviteRequestPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// Handle starting a peek on a remote server.
|
||||
func (h *httpFederationInternalAPI) PerformOutboundPeek(
|
||||
ctx context.Context,
|
||||
request *api.PerformOutboundPeekRequest,
|
||||
response *api.PerformOutboundPeekResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformOutboundPeekRequest")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformOutboundPeekRequestPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) PerformServersAlive(
|
||||
ctx context.Context,
|
||||
request *api.PerformServersAliveRequest,
|
||||
response *api.PerformServersAliveResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformServersAlive")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformServersAlivePath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// QueryJoinedHostServerNamesInRoom implements FederationInternalAPI
|
||||
func (h *httpFederationInternalAPI) QueryJoinedHostServerNamesInRoom(
|
||||
ctx context.Context,
|
||||
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
||||
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryJoinedHostServerNamesInRoom")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIQueryJoinedHostServerNamesInRoomPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// Handle an instruction to make_join & send_join with a remote server.
|
||||
func (h *httpFederationInternalAPI) PerformJoin(
|
||||
ctx context.Context,
|
||||
request *api.PerformJoinRequest,
|
||||
response *api.PerformJoinResponse,
|
||||
) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformJoinRequest")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformJoinRequestPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
if err != nil {
|
||||
response.LastError = &gomatrix.HTTPError{
|
||||
Message: err.Error(),
|
||||
Code: 0,
|
||||
WrappedError: err,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle an instruction to make_join & send_join with a remote server.
|
||||
func (h *httpFederationInternalAPI) PerformDirectoryLookup(
|
||||
ctx context.Context,
|
||||
request *api.PerformDirectoryLookupRequest,
|
||||
response *api.PerformDirectoryLookupResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformDirectoryLookup")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformDirectoryLookupRequestPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
|
||||
func (h *httpFederationInternalAPI) PerformBroadcastEDU(
|
||||
ctx context.Context,
|
||||
request *api.PerformBroadcastEDURequest,
|
||||
response *api.PerformBroadcastEDUResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIPerformBroadcastEDUPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
type getUserDevices struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
UserID string
|
||||
Res *gomatrixserverlib.RespUserDevices
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) GetUserDevices(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, userID string,
|
||||
) (gomatrixserverlib.RespUserDevices, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "GetUserDevices")
|
||||
defer span.Finish()
|
||||
|
||||
var result gomatrixserverlib.RespUserDevices
|
||||
request := getUserDevices{
|
||||
S: s,
|
||||
UserID: userID,
|
||||
}
|
||||
var response getUserDevices
|
||||
apiURL := h.federationAPIURL + FederationAPIGetUserDevicesPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return result, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type claimKeys struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
OneTimeKeys map[string]map[string]string
|
||||
Res *gomatrixserverlib.RespClaimKeys
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) ClaimKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, oneTimeKeys map[string]map[string]string,
|
||||
) (gomatrixserverlib.RespClaimKeys, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "ClaimKeys")
|
||||
defer span.Finish()
|
||||
|
||||
var result gomatrixserverlib.RespClaimKeys
|
||||
request := claimKeys{
|
||||
S: s,
|
||||
OneTimeKeys: oneTimeKeys,
|
||||
}
|
||||
var response claimKeys
|
||||
apiURL := h.federationAPIURL + FederationAPIClaimKeysPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return result, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type queryKeys struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
Keys map[string][]string
|
||||
Res *gomatrixserverlib.RespQueryKeys
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) QueryKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, keys map[string][]string,
|
||||
) (gomatrixserverlib.RespQueryKeys, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryKeys")
|
||||
defer span.Finish()
|
||||
|
||||
var result gomatrixserverlib.RespQueryKeys
|
||||
request := queryKeys{
|
||||
S: s,
|
||||
Keys: keys,
|
||||
}
|
||||
var response queryKeys
|
||||
apiURL := h.federationAPIURL + FederationAPIQueryKeysPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return result, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type backfill struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
RoomID string
|
||||
Limit int
|
||||
EventIDs []string
|
||||
Res *gomatrixserverlib.Transaction
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) Backfill(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID string, limit int, eventIDs []string,
|
||||
) (gomatrixserverlib.Transaction, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Backfill")
|
||||
defer span.Finish()
|
||||
|
||||
request := backfill{
|
||||
S: s,
|
||||
RoomID: roomID,
|
||||
Limit: limit,
|
||||
EventIDs: eventIDs,
|
||||
}
|
||||
var response backfill
|
||||
apiURL := h.federationAPIURL + FederationAPIBackfillPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return gomatrixserverlib.Transaction{}, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return gomatrixserverlib.Transaction{}, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type lookupState struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
RoomID string
|
||||
EventID string
|
||||
RoomVersion gomatrixserverlib.RoomVersion
|
||||
Res *gomatrixserverlib.RespState
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) LookupState(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string, roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (gomatrixserverlib.RespState, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupState")
|
||||
defer span.Finish()
|
||||
|
||||
request := lookupState{
|
||||
S: s,
|
||||
RoomID: roomID,
|
||||
EventID: eventID,
|
||||
RoomVersion: roomVersion,
|
||||
}
|
||||
var response lookupState
|
||||
apiURL := h.federationAPIURL + FederationAPILookupStatePath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespState{}, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return gomatrixserverlib.RespState{}, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type lookupStateIDs struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
RoomID string
|
||||
EventID string
|
||||
Res *gomatrixserverlib.RespStateIDs
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) LookupStateIDs(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, roomID, eventID string,
|
||||
) (gomatrixserverlib.RespStateIDs, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupStateIDs")
|
||||
defer span.Finish()
|
||||
|
||||
request := lookupStateIDs{
|
||||
S: s,
|
||||
RoomID: roomID,
|
||||
EventID: eventID,
|
||||
}
|
||||
var response lookupStateIDs
|
||||
apiURL := h.federationAPIURL + FederationAPILookupStateIDsPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return gomatrixserverlib.RespStateIDs{}, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return gomatrixserverlib.RespStateIDs{}, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
type getEvent struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
EventID string
|
||||
Res *gomatrixserverlib.Transaction
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) GetEvent(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, eventID string,
|
||||
) (gomatrixserverlib.Transaction, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "GetEvent")
|
||||
defer span.Finish()
|
||||
|
||||
request := getEvent{
|
||||
S: s,
|
||||
EventID: eventID,
|
||||
}
|
||||
var response getEvent
|
||||
apiURL := h.federationAPIURL + FederationAPIGetEventPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return gomatrixserverlib.Transaction{}, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return gomatrixserverlib.Transaction{}, response.Err
|
||||
}
|
||||
return *response.Res, nil
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) QueryServerKeys(
|
||||
ctx context.Context, req *api.QueryServerKeysRequest, res *api.QueryServerKeysResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerKeys")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIQueryServerKeysPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, req, res)
|
||||
}
|
||||
|
||||
type lookupServerKeys struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
KeyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp
|
||||
ServerKeys []gomatrixserverlib.ServerKeys
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) LookupServerKeys(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, keyRequests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "LookupServerKeys")
|
||||
defer span.Finish()
|
||||
|
||||
request := lookupServerKeys{
|
||||
S: s,
|
||||
KeyRequests: keyRequests,
|
||||
}
|
||||
var response lookupServerKeys
|
||||
apiURL := h.federationAPIURL + FederationAPILookupServerKeysPath
|
||||
err := httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return []gomatrixserverlib.ServerKeys{}, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return []gomatrixserverlib.ServerKeys{}, response.Err
|
||||
}
|
||||
return response.ServerKeys, nil
|
||||
}
|
||||
|
||||
type eventRelationships struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
Req gomatrixserverlib.MSC2836EventRelationshipsRequest
|
||||
RoomVer gomatrixserverlib.RoomVersion
|
||||
Res gomatrixserverlib.MSC2836EventRelationshipsResponse
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) MSC2836EventRelationships(
|
||||
ctx context.Context, s gomatrixserverlib.ServerName, r gomatrixserverlib.MSC2836EventRelationshipsRequest,
|
||||
roomVersion gomatrixserverlib.RoomVersion,
|
||||
) (res gomatrixserverlib.MSC2836EventRelationshipsResponse, err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2836EventRelationships")
|
||||
defer span.Finish()
|
||||
|
||||
request := eventRelationships{
|
||||
S: s,
|
||||
Req: r,
|
||||
RoomVer: roomVersion,
|
||||
}
|
||||
var response eventRelationships
|
||||
apiURL := h.federationAPIURL + FederationAPIEventRelationshipsPath
|
||||
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return res, response.Err
|
||||
}
|
||||
return response.Res, nil
|
||||
}
|
||||
|
||||
type spacesReq struct {
|
||||
S gomatrixserverlib.ServerName
|
||||
Req gomatrixserverlib.MSC2946SpacesRequest
|
||||
RoomID string
|
||||
Res gomatrixserverlib.MSC2946SpacesResponse
|
||||
Err *api.FederationClientError
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) MSC2946Spaces(
|
||||
ctx context.Context, dst gomatrixserverlib.ServerName, roomID string, r gomatrixserverlib.MSC2946SpacesRequest,
|
||||
) (res gomatrixserverlib.MSC2946SpacesResponse, err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "MSC2946Spaces")
|
||||
defer span.Finish()
|
||||
|
||||
request := spacesReq{
|
||||
S: dst,
|
||||
Req: r,
|
||||
RoomID: roomID,
|
||||
}
|
||||
var response spacesReq
|
||||
apiURL := h.federationAPIURL + FederationAPISpacesSummaryPath
|
||||
err = httputil.PostJSON(ctx, span, h.httpClient, apiURL, &request, &response)
|
||||
if err != nil {
|
||||
return res, err
|
||||
}
|
||||
if response.Err != nil {
|
||||
return res, response.Err
|
||||
}
|
||||
return response.Res, nil
|
||||
}
|
||||
|
||||
func (s *httpFederationInternalAPI) KeyRing() *gomatrixserverlib.KeyRing {
|
||||
// This is a bit of a cheat - we tell gomatrixserverlib that this API is
|
||||
// both the key database and the key fetcher. While this does have the
|
||||
// rather unfortunate effect of preventing gomatrixserverlib from handling
|
||||
// key fetchers directly, we can at least reimplement this behaviour on
|
||||
// the other end of the API.
|
||||
return &gomatrixserverlib.KeyRing{
|
||||
KeyDatabase: s,
|
||||
KeyFetchers: []gomatrixserverlib.KeyFetcher{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *httpFederationInternalAPI) FetcherName() string {
|
||||
return "httpServerKeyInternalAPI"
|
||||
}
|
||||
|
||||
func (s *httpFederationInternalAPI) StoreKeys(
|
||||
_ context.Context,
|
||||
results map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
// Run in a background context - we don't want to stop this work just
|
||||
// because the caller gives up waiting.
|
||||
ctx := context.Background()
|
||||
request := api.InputPublicKeysRequest{
|
||||
Keys: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult),
|
||||
}
|
||||
response := api.InputPublicKeysResponse{}
|
||||
for req, res := range results {
|
||||
request.Keys[req] = res
|
||||
s.cache.StoreServerKey(req, res)
|
||||
}
|
||||
return s.InputPublicKeys(ctx, &request, &response)
|
||||
}
|
||||
|
||||
func (s *httpFederationInternalAPI) FetchKeys(
|
||||
_ context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
// Run in a background context - we don't want to stop this work just
|
||||
// because the caller gives up waiting.
|
||||
ctx := context.Background()
|
||||
result := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
|
||||
request := api.QueryPublicKeysRequest{
|
||||
Requests: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp),
|
||||
}
|
||||
response := api.QueryPublicKeysResponse{
|
||||
Results: make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult),
|
||||
}
|
||||
for req, ts := range requests {
|
||||
if res, ok := s.cache.GetServerKey(req, ts); ok {
|
||||
result[req] = res
|
||||
continue
|
||||
}
|
||||
request.Requests[req] = ts
|
||||
}
|
||||
err := s.QueryPublicKeys(ctx, &request, &response)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for req, res := range response.Results {
|
||||
result[req] = res
|
||||
s.cache.StoreServerKey(req, res)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) InputPublicKeys(
|
||||
ctx context.Context,
|
||||
request *api.InputPublicKeysRequest,
|
||||
response *api.InputPublicKeysResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "InputPublicKey")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIInputPublicKeyPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
func (h *httpFederationInternalAPI) QueryPublicKeys(
|
||||
ctx context.Context,
|
||||
request *api.QueryPublicKeysRequest,
|
||||
response *api.QueryPublicKeysResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "QueryPublicKey")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.federationAPIURL + FederationAPIQueryPublicKeyPath
|
||||
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
374
federationapi/inthttp/server.go
Normal file
374
federationapi/inthttp/server.go
Normal file
|
@ -0,0 +1,374 @@
|
|||
package inthttp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// AddRoutes adds the FederationInternalAPI handlers to the http.ServeMux.
|
||||
// nolint:gocyclo
|
||||
func AddRoutes(intAPI api.FederationInternalAPI, internalAPIMux *mux.Router) {
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIQueryJoinedHostServerNamesInRoomPath,
|
||||
httputil.MakeInternalAPI("QueryJoinedHostServerNamesInRoom", func(req *http.Request) util.JSONResponse {
|
||||
var request api.QueryJoinedHostServerNamesInRoomRequest
|
||||
var response api.QueryJoinedHostServerNamesInRoomResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
if err := intAPI.QueryJoinedHostServerNamesInRoom(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformJoinRequestPath,
|
||||
httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformJoinRequest
|
||||
var response api.PerformJoinResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
intAPI.PerformJoin(req.Context(), &request, &response)
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformLeaveRequestPath,
|
||||
httputil.MakeInternalAPI("PerformLeaveRequest", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformLeaveRequest
|
||||
var response api.PerformLeaveResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.PerformLeave(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformInviteRequestPath,
|
||||
httputil.MakeInternalAPI("PerformInviteRequest", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformInviteRequest
|
||||
var response api.PerformInviteResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.PerformInvite(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformDirectoryLookupRequestPath,
|
||||
httputil.MakeInternalAPI("PerformDirectoryLookupRequest", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformDirectoryLookupRequest
|
||||
var response api.PerformDirectoryLookupResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.PerformDirectoryLookup(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformServersAlivePath,
|
||||
httputil.MakeInternalAPI("PerformServersAliveRequest", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformServersAliveRequest
|
||||
var response api.PerformServersAliveResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.PerformServersAlive(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIPerformBroadcastEDUPath,
|
||||
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
|
||||
var request api.PerformBroadcastEDURequest
|
||||
var response api.PerformBroadcastEDUResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIGetUserDevicesPath,
|
||||
httputil.MakeInternalAPI("GetUserDevices", func(req *http.Request) util.JSONResponse {
|
||||
var request getUserDevices
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.GetUserDevices(req.Context(), request.S, request.UserID)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIClaimKeysPath,
|
||||
httputil.MakeInternalAPI("ClaimKeys", func(req *http.Request) util.JSONResponse {
|
||||
var request claimKeys
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.ClaimKeys(req.Context(), request.S, request.OneTimeKeys)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIQueryKeysPath,
|
||||
httputil.MakeInternalAPI("QueryKeys", func(req *http.Request) util.JSONResponse {
|
||||
var request queryKeys
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.QueryKeys(req.Context(), request.S, request.Keys)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIBackfillPath,
|
||||
httputil.MakeInternalAPI("Backfill", func(req *http.Request) util.JSONResponse {
|
||||
var request backfill
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.Backfill(req.Context(), request.S, request.RoomID, request.Limit, request.EventIDs)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPILookupStatePath,
|
||||
httputil.MakeInternalAPI("LookupState", func(req *http.Request) util.JSONResponse {
|
||||
var request lookupState
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.LookupState(req.Context(), request.S, request.RoomID, request.EventID, request.RoomVersion)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPILookupStateIDsPath,
|
||||
httputil.MakeInternalAPI("LookupStateIDs", func(req *http.Request) util.JSONResponse {
|
||||
var request lookupStateIDs
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.LookupStateIDs(req.Context(), request.S, request.RoomID, request.EventID)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIGetEventPath,
|
||||
httputil.MakeInternalAPI("GetEvent", func(req *http.Request) util.JSONResponse {
|
||||
var request getEvent
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.GetEvent(req.Context(), request.S, request.EventID)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = &res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIQueryServerKeysPath,
|
||||
httputil.MakeInternalAPI("QueryServerKeys", func(req *http.Request) util.JSONResponse {
|
||||
var request api.QueryServerKeysRequest
|
||||
var response api.QueryServerKeysResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.QueryServerKeys(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPILookupServerKeysPath,
|
||||
httputil.MakeInternalAPI("LookupServerKeys", func(req *http.Request) util.JSONResponse {
|
||||
var request lookupServerKeys
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.LookupServerKeys(req.Context(), request.S, request.KeyRequests)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.ServerKeys = res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPIEventRelationshipsPath,
|
||||
httputil.MakeInternalAPI("MSC2836EventRelationships", func(req *http.Request) util.JSONResponse {
|
||||
var request eventRelationships
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.MSC2836EventRelationships(req.Context(), request.S, request.Req, request.RoomVer)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(
|
||||
FederationAPISpacesSummaryPath,
|
||||
httputil.MakeInternalAPI("MSC2946SpacesSummary", func(req *http.Request) util.JSONResponse {
|
||||
var request spacesReq
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
res, err := intAPI.MSC2946Spaces(req.Context(), request.S, request.RoomID, request.Req)
|
||||
if err != nil {
|
||||
ferr, ok := err.(*api.FederationClientError)
|
||||
if ok {
|
||||
request.Err = ferr
|
||||
} else {
|
||||
request.Err = &api.FederationClientError{
|
||||
Err: err.Error(),
|
||||
}
|
||||
}
|
||||
}
|
||||
request.Res = res
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: request}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(FederationAPIQueryPublicKeyPath,
|
||||
httputil.MakeInternalAPI("queryPublicKeys", func(req *http.Request) util.JSONResponse {
|
||||
request := api.QueryPublicKeysRequest{}
|
||||
response := api.QueryPublicKeysResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
keys, err := intAPI.FetchKeys(req.Context(), request.Requests)
|
||||
if err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
response.Results = keys
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
internalAPIMux.Handle(FederationAPIInputPublicKeyPath,
|
||||
httputil.MakeInternalAPI("inputPublicKeys", func(req *http.Request) util.JSONResponse {
|
||||
request := api.InputPublicKeysRequest{}
|
||||
response := api.InputPublicKeysResponse{}
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.MessageResponse(http.StatusBadRequest, err.Error())
|
||||
}
|
||||
if err := intAPI.StoreKeys(req.Context(), request.Keys); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
448
federationapi/queue/destinationqueue.go
Normal file
448
federationapi/queue/destinationqueue.go
Normal file
|
@ -0,0 +1,448 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
const (
|
||||
maxPDUsPerTransaction = 50
|
||||
maxEDUsPerTransaction = 50
|
||||
maxPDUsInMemory = 128
|
||||
maxEDUsInMemory = 128
|
||||
queueIdleTimeout = time.Second * 30
|
||||
)
|
||||
|
||||
// destinationQueue is a queue of events for a single destination.
|
||||
// It is responsible for sending the events to the destination and
|
||||
// ensures that only one request is in flight to a given destination
|
||||
// at a time.
|
||||
type destinationQueue struct {
|
||||
queues *OutgoingQueues
|
||||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
signing *SigningInfo
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
client *gomatrixserverlib.FederationClient // federation client
|
||||
origin gomatrixserverlib.ServerName // origin of requests
|
||||
destination gomatrixserverlib.ServerName // destination of requests
|
||||
running atomic.Bool // is the queue worker running?
|
||||
backingOff atomic.Bool // true if we're backing off
|
||||
overflowed atomic.Bool // the queues exceed maxPDUsInMemory/maxEDUsInMemory, so we should consult the database for more
|
||||
statistics *statistics.ServerStatistics // statistics about this remote server
|
||||
transactionIDMutex sync.Mutex // protects transactionID
|
||||
transactionID gomatrixserverlib.TransactionID // last transaction ID if retrying, or "" if last txn was successful
|
||||
notify chan struct{} // interrupts idle wait pending PDUs/EDUs
|
||||
pendingPDUs []*queuedPDU // PDUs waiting to be sent
|
||||
pendingEDUs []*queuedEDU // EDUs waiting to be sent
|
||||
pendingMutex sync.RWMutex // protects pendingPDUs and pendingEDUs
|
||||
interruptBackoff chan bool // interrupts backoff
|
||||
}
|
||||
|
||||
// Send event adds the event to the pending queue for the destination.
|
||||
// If the queue is empty then it starts a background goroutine to
|
||||
// start sending events to that destination.
|
||||
func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, receipt *shared.Receipt) {
|
||||
if event == nil {
|
||||
logrus.Errorf("attempt to send nil PDU with destination %q", oq.destination)
|
||||
return
|
||||
}
|
||||
// Create a database entry that associates the given PDU NID with
|
||||
// this destination queue. We'll then be able to retrieve the PDU
|
||||
// later.
|
||||
if err := oq.db.AssociatePDUWithDestination(
|
||||
context.TODO(),
|
||||
"", // TODO: remove this, as we don't need to persist the transaction ID
|
||||
oq.destination, // the destination server name
|
||||
receipt, // NIDs from federationapi_queue_json table
|
||||
); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
|
||||
return
|
||||
}
|
||||
// Check if the destination is blacklisted. If it isn't then wake
|
||||
// up the queue.
|
||||
if !oq.statistics.Blacklisted() {
|
||||
// If there's room in memory to hold the event then add it to the
|
||||
// list.
|
||||
oq.pendingMutex.Lock()
|
||||
if len(oq.pendingPDUs) < maxPDUsInMemory {
|
||||
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
|
||||
pdu: event,
|
||||
receipt: receipt,
|
||||
})
|
||||
} else {
|
||||
oq.overflowed.Store(true)
|
||||
}
|
||||
oq.pendingMutex.Unlock()
|
||||
// Wake up the queue if it's asleep.
|
||||
oq.wakeQueueIfNeeded()
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// sendEDU adds the EDU event to the pending queue for the destination.
|
||||
// If the queue is empty then it starts a background goroutine to
|
||||
// start sending events to that destination.
|
||||
func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *shared.Receipt) {
|
||||
if event == nil {
|
||||
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
|
||||
return
|
||||
}
|
||||
// Create a database entry that associates the given PDU NID with
|
||||
// this destination queue. We'll then be able to retrieve the PDU
|
||||
// later.
|
||||
if err := oq.db.AssociateEDUWithDestination(
|
||||
context.TODO(),
|
||||
oq.destination, // the destination server name
|
||||
receipt, // NIDs from federationapi_queue_json table
|
||||
); err != nil {
|
||||
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
|
||||
return
|
||||
}
|
||||
// Check if the destination is blacklisted. If it isn't then wake
|
||||
// up the queue.
|
||||
if !oq.statistics.Blacklisted() {
|
||||
// If there's room in memory to hold the event then add it to the
|
||||
// list.
|
||||
oq.pendingMutex.Lock()
|
||||
if len(oq.pendingEDUs) < maxEDUsInMemory {
|
||||
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
|
||||
edu: event,
|
||||
receipt: receipt,
|
||||
})
|
||||
} else {
|
||||
oq.overflowed.Store(true)
|
||||
}
|
||||
oq.pendingMutex.Unlock()
|
||||
// Wake up the queue if it's asleep.
|
||||
oq.wakeQueueIfNeeded()
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// wakeQueueIfNeeded will wake up the destination queue if it is
|
||||
// not already running. If it is running but it is backing off
|
||||
// then we will interrupt the backoff, causing any federation
|
||||
// requests to retry.
|
||||
func (oq *destinationQueue) wakeQueueIfNeeded() {
|
||||
// If we are backing off then interrupt the backoff.
|
||||
if oq.backingOff.CAS(true, false) {
|
||||
oq.interruptBackoff <- true
|
||||
}
|
||||
// If we aren't running then wake up the queue.
|
||||
if !oq.running.Load() {
|
||||
// Start the queue.
|
||||
go oq.backgroundSend()
|
||||
}
|
||||
}
|
||||
|
||||
// getPendingFromDatabase will look at the database and see if
|
||||
// there are any persisted events that haven't been sent to this
|
||||
// destination yet. If so, they will be queued up.
|
||||
func (oq *destinationQueue) getPendingFromDatabase() {
|
||||
// Check to see if there's anything to do for this server
|
||||
// in the database.
|
||||
retrieved := false
|
||||
ctx := context.Background()
|
||||
oq.pendingMutex.Lock()
|
||||
defer oq.pendingMutex.Unlock()
|
||||
|
||||
// Take a note of all of the PDUs and EDUs that we already
|
||||
// have cached. We will index them based on the receipt,
|
||||
// which ultimately just contains the index of the PDU/EDU
|
||||
// in the database.
|
||||
gotPDUs := map[string]struct{}{}
|
||||
gotEDUs := map[string]struct{}{}
|
||||
for _, pdu := range oq.pendingPDUs {
|
||||
gotPDUs[pdu.receipt.String()] = struct{}{}
|
||||
}
|
||||
for _, edu := range oq.pendingEDUs {
|
||||
gotEDUs[edu.receipt.String()] = struct{}{}
|
||||
}
|
||||
|
||||
if pduCapacity := maxPDUsInMemory - len(oq.pendingPDUs); pduCapacity > 0 {
|
||||
// We have room in memory for some PDUs - let's request no more than that.
|
||||
if pdus, err := oq.db.GetPendingPDUs(ctx, oq.destination, pduCapacity); err == nil {
|
||||
for receipt, pdu := range pdus {
|
||||
if _, ok := gotPDUs[receipt.String()]; ok {
|
||||
continue
|
||||
}
|
||||
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{receipt, pdu})
|
||||
retrieved = true
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Errorf("Failed to get pending PDUs for %q", oq.destination)
|
||||
}
|
||||
}
|
||||
if eduCapacity := maxEDUsInMemory - len(oq.pendingEDUs); eduCapacity > 0 {
|
||||
// We have room in memory for some EDUs - let's request no more than that.
|
||||
if edus, err := oq.db.GetPendingEDUs(ctx, oq.destination, eduCapacity); err == nil {
|
||||
for receipt, edu := range edus {
|
||||
if _, ok := gotEDUs[receipt.String()]; ok {
|
||||
continue
|
||||
}
|
||||
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{receipt, edu})
|
||||
retrieved = true
|
||||
}
|
||||
} else {
|
||||
logrus.WithError(err).Errorf("Failed to get pending EDUs for %q", oq.destination)
|
||||
}
|
||||
}
|
||||
// If we've retrieved all of the events from the database with room to spare
|
||||
// in memory then we'll no longer consider this queue to be overflowed.
|
||||
if len(oq.pendingPDUs) < maxPDUsInMemory && len(oq.pendingEDUs) < maxEDUsInMemory {
|
||||
oq.overflowed.Store(false)
|
||||
}
|
||||
// If we've retrieved some events then notify the destination queue goroutine.
|
||||
if retrieved {
|
||||
select {
|
||||
case oq.notify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// backgroundSend is the worker goroutine for sending events.
|
||||
func (oq *destinationQueue) backgroundSend() {
|
||||
// Check if a worker is already running, and if it isn't, then
|
||||
// mark it as started.
|
||||
if !oq.running.CAS(false, true) {
|
||||
return
|
||||
}
|
||||
destinationQueueRunning.Inc()
|
||||
defer destinationQueueRunning.Dec()
|
||||
defer oq.queues.clearQueue(oq)
|
||||
defer oq.running.Store(false)
|
||||
|
||||
// Mark the queue as overflowed, so we will consult the database
|
||||
// to see if there's anything new to send.
|
||||
oq.overflowed.Store(true)
|
||||
|
||||
for {
|
||||
// If we are overflowing memory and have sent things out to the
|
||||
// database then we can look up what those things are.
|
||||
if oq.overflowed.Load() {
|
||||
oq.getPendingFromDatabase()
|
||||
}
|
||||
|
||||
// If we have nothing to do then wait either for incoming events, or
|
||||
// until we hit an idle timeout.
|
||||
select {
|
||||
case <-oq.notify:
|
||||
// There's work to do, either because getPendingFromDatabase
|
||||
// told us there is, or because a new event has come in via
|
||||
// sendEvent/sendEDU.
|
||||
case <-time.After(queueIdleTimeout):
|
||||
// The worker is idle so stop the goroutine. It'll get
|
||||
// restarted automatically the next time we have an event to
|
||||
// send.
|
||||
return
|
||||
}
|
||||
|
||||
// If we are backing off this server then wait for the
|
||||
// backoff duration to complete first, or until explicitly
|
||||
// told to retry.
|
||||
until, blacklisted := oq.statistics.BackoffInfo()
|
||||
if blacklisted {
|
||||
// It's been suggested that we should give up because the backoff
|
||||
// has exceeded a maximum allowable value. Clean up the in-memory
|
||||
// buffers at this point. The PDU clean-up is already on a defer.
|
||||
logrus.Warnf("Blacklisting %q due to exceeding backoff threshold", oq.destination)
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs {
|
||||
oq.pendingPDUs[i] = nil
|
||||
}
|
||||
for i := range oq.pendingEDUs {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = nil
|
||||
oq.pendingEDUs = nil
|
||||
oq.pendingMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if until != nil && until.After(time.Now()) {
|
||||
// We haven't backed off yet, so wait for the suggested amount of
|
||||
// time.
|
||||
duration := time.Until(*until)
|
||||
logrus.Warnf("Backing off %q for %s", oq.destination, duration)
|
||||
oq.backingOff.Store(true)
|
||||
destinationQueueBackingOff.Inc()
|
||||
select {
|
||||
case <-time.After(duration):
|
||||
case <-oq.interruptBackoff:
|
||||
}
|
||||
destinationQueueBackingOff.Dec()
|
||||
oq.backingOff.Store(false)
|
||||
}
|
||||
|
||||
// Work out which PDUs/EDUs to include in the next transaction.
|
||||
oq.pendingMutex.RLock()
|
||||
pduCount := len(oq.pendingPDUs)
|
||||
eduCount := len(oq.pendingEDUs)
|
||||
if pduCount > maxPDUsPerTransaction {
|
||||
pduCount = maxPDUsPerTransaction
|
||||
}
|
||||
if eduCount > maxEDUsPerTransaction {
|
||||
eduCount = maxEDUsPerTransaction
|
||||
}
|
||||
toSendPDUs := oq.pendingPDUs[:pduCount]
|
||||
toSendEDUs := oq.pendingEDUs[:eduCount]
|
||||
oq.pendingMutex.RUnlock()
|
||||
|
||||
// If we have pending PDUs or EDUs then construct a transaction.
|
||||
// Try sending the next transaction and see what happens.
|
||||
transaction, pc, ec, terr := oq.nextTransaction(toSendPDUs, toSendEDUs)
|
||||
if terr != nil {
|
||||
// We failed to send the transaction. Mark it as a failure.
|
||||
oq.statistics.Failure()
|
||||
|
||||
} else if transaction {
|
||||
// If we successfully sent the transaction then clear out
|
||||
// the pending events and EDUs, and wipe our transaction ID.
|
||||
oq.statistics.Success()
|
||||
oq.pendingMutex.Lock()
|
||||
for i := range oq.pendingPDUs[:pc] {
|
||||
oq.pendingPDUs[i] = nil
|
||||
}
|
||||
for i := range oq.pendingEDUs[:ec] {
|
||||
oq.pendingEDUs[i] = nil
|
||||
}
|
||||
oq.pendingPDUs = oq.pendingPDUs[pc:]
|
||||
oq.pendingEDUs = oq.pendingEDUs[ec:]
|
||||
oq.pendingMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nextTransaction creates a new transaction from the pending event
|
||||
// queue and sends it. Returns true if a transaction was sent or
|
||||
// false otherwise.
|
||||
func (oq *destinationQueue) nextTransaction(
|
||||
pdus []*queuedPDU,
|
||||
edus []*queuedEDU,
|
||||
) (bool, int, int, error) {
|
||||
// If there's no projected transaction ID then generate one. If
|
||||
// the transaction succeeds then we'll set it back to "" so that
|
||||
// we generate a new one next time. If it fails, we'll preserve
|
||||
// it so that we retry with the same transaction ID.
|
||||
oq.transactionIDMutex.Lock()
|
||||
if oq.transactionID == "" {
|
||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
||||
oq.transactionID = gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, oq.statistics.SuccessCount()))
|
||||
}
|
||||
oq.transactionIDMutex.Unlock()
|
||||
|
||||
// Create the transaction.
|
||||
t := gomatrixserverlib.Transaction{
|
||||
PDUs: []json.RawMessage{},
|
||||
EDUs: []gomatrixserverlib.EDU{},
|
||||
}
|
||||
t.Origin = oq.origin
|
||||
t.Destination = oq.destination
|
||||
t.OriginServerTS = gomatrixserverlib.AsTimestamp(time.Now())
|
||||
t.TransactionID = oq.transactionID
|
||||
|
||||
// If we didn't get anything from the database and there are no
|
||||
// pending EDUs then there's nothing to do - stop here.
|
||||
if len(pdus) == 0 && len(edus) == 0 {
|
||||
return false, 0, 0, nil
|
||||
}
|
||||
|
||||
var pduReceipts []*shared.Receipt
|
||||
var eduReceipts []*shared.Receipt
|
||||
|
||||
// Go through PDUs that we retrieved from the database, if any,
|
||||
// and add them into the transaction.
|
||||
for _, pdu := range pdus {
|
||||
if pdu == nil || pdu.pdu == nil {
|
||||
continue
|
||||
}
|
||||
// Append the JSON of the event, since this is a json.RawMessage type in the
|
||||
// gomatrixserverlib.Transaction struct
|
||||
t.PDUs = append(t.PDUs, pdu.pdu.JSON())
|
||||
pduReceipts = append(pduReceipts, pdu.receipt)
|
||||
}
|
||||
|
||||
// Do the same for pending EDUS in the queue.
|
||||
for _, edu := range edus {
|
||||
if edu == nil || edu.edu == nil {
|
||||
continue
|
||||
}
|
||||
t.EDUs = append(t.EDUs, *edu.edu)
|
||||
eduReceipts = append(eduReceipts, edu.receipt)
|
||||
}
|
||||
|
||||
logrus.WithField("server_name", oq.destination).Debugf("Sending transaction %q containing %d PDUs, %d EDUs", t.TransactionID, len(t.PDUs), len(t.EDUs))
|
||||
|
||||
// Try to send the transaction to the destination server.
|
||||
// TODO: we should check for 500-ish fails vs 400-ish here,
|
||||
// since we shouldn't queue things indefinitely in response
|
||||
// to a 400-ish error
|
||||
ctx, cancel := context.WithTimeout(oq.process.Context(), time.Minute*5)
|
||||
defer cancel()
|
||||
_, err := oq.client.SendTransaction(ctx, t)
|
||||
switch err.(type) {
|
||||
case nil:
|
||||
// Clean up the transaction in the database.
|
||||
if pduReceipts != nil {
|
||||
//logrus.Infof("Cleaning PDUs %q", pduReceipt.String())
|
||||
if err = oq.db.CleanPDUs(context.Background(), oq.destination, pduReceipts); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to clean PDUs for server %q", t.Destination)
|
||||
}
|
||||
}
|
||||
if eduReceipts != nil {
|
||||
//logrus.Infof("Cleaning EDUs %q", eduReceipt.String())
|
||||
if err = oq.db.CleanEDUs(context.Background(), oq.destination, eduReceipts); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to clean EDUs for server %q", t.Destination)
|
||||
}
|
||||
}
|
||||
// Reset the transaction ID.
|
||||
oq.transactionIDMutex.Lock()
|
||||
oq.transactionID = ""
|
||||
oq.transactionIDMutex.Unlock()
|
||||
return true, len(t.PDUs), len(t.EDUs), nil
|
||||
case gomatrix.HTTPError:
|
||||
// Report that we failed to send the transaction and we
|
||||
// will retry again, subject to backoff.
|
||||
return false, 0, 0, err
|
||||
default:
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"destination": oq.destination,
|
||||
logrus.ErrorKey: err,
|
||||
}).Debugf("Failed to send transaction %q", t.TransactionID)
|
||||
return false, 0, 0, err
|
||||
}
|
||||
}
|
339
federationapi/queue/queue.go
Normal file
339
federationapi/queue/queue.go
Normal file
|
@ -0,0 +1,339 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package queue
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/process"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// OutgoingQueues is a collection of queues for sending transactions to other
|
||||
// matrix servers
|
||||
type OutgoingQueues struct {
|
||||
db storage.Database
|
||||
process *process.ProcessContext
|
||||
disabled bool
|
||||
rsAPI api.RoomserverInternalAPI
|
||||
origin gomatrixserverlib.ServerName
|
||||
client *gomatrixserverlib.FederationClient
|
||||
statistics *statistics.Statistics
|
||||
signing *SigningInfo
|
||||
queuesMutex sync.Mutex // protects the below
|
||||
queues map[gomatrixserverlib.ServerName]*destinationQueue
|
||||
}
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(
|
||||
destinationQueueTotal, destinationQueueRunning,
|
||||
destinationQueueBackingOff,
|
||||
)
|
||||
}
|
||||
|
||||
var destinationQueueTotal = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
Subsystem: "federationapi",
|
||||
Name: "destination_queues_total",
|
||||
},
|
||||
)
|
||||
|
||||
var destinationQueueRunning = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
Subsystem: "federationapi",
|
||||
Name: "destination_queues_running",
|
||||
},
|
||||
)
|
||||
|
||||
var destinationQueueBackingOff = prometheus.NewGauge(
|
||||
prometheus.GaugeOpts{
|
||||
Namespace: "dendrite",
|
||||
Subsystem: "federationapi",
|
||||
Name: "destination_queues_backing_off",
|
||||
},
|
||||
)
|
||||
|
||||
// NewOutgoingQueues makes a new OutgoingQueues
|
||||
func NewOutgoingQueues(
|
||||
db storage.Database,
|
||||
process *process.ProcessContext,
|
||||
disabled bool,
|
||||
origin gomatrixserverlib.ServerName,
|
||||
client *gomatrixserverlib.FederationClient,
|
||||
rsAPI api.RoomserverInternalAPI,
|
||||
statistics *statistics.Statistics,
|
||||
signing *SigningInfo,
|
||||
) *OutgoingQueues {
|
||||
queues := &OutgoingQueues{
|
||||
disabled: disabled,
|
||||
process: process,
|
||||
db: db,
|
||||
rsAPI: rsAPI,
|
||||
origin: origin,
|
||||
client: client,
|
||||
statistics: statistics,
|
||||
signing: signing,
|
||||
queues: map[gomatrixserverlib.ServerName]*destinationQueue{},
|
||||
}
|
||||
// Look up which servers we have pending items for and then rehydrate those queues.
|
||||
if !disabled {
|
||||
time.AfterFunc(time.Second*5, func() {
|
||||
serverNames := map[gomatrixserverlib.ServerName]struct{}{}
|
||||
if names, err := db.GetPendingPDUServerNames(context.Background()); err == nil {
|
||||
for _, serverName := range names {
|
||||
serverNames[serverName] = struct{}{}
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("Failed to get PDU server names for destination queue hydration")
|
||||
}
|
||||
if names, err := db.GetPendingEDUServerNames(context.Background()); err == nil {
|
||||
for _, serverName := range names {
|
||||
serverNames[serverName] = struct{}{}
|
||||
}
|
||||
} else {
|
||||
log.WithError(err).Error("Failed to get EDU server names for destination queue hydration")
|
||||
}
|
||||
for serverName := range serverNames {
|
||||
if queue := queues.getQueue(serverName); queue != nil {
|
||||
queue.wakeQueueIfNeeded()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
return queues
|
||||
}
|
||||
|
||||
// TODO: Move this somewhere useful for other components as we often need to ferry these 3 variables
|
||||
// around together
|
||||
type SigningInfo struct {
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
KeyID gomatrixserverlib.KeyID
|
||||
PrivateKey ed25519.PrivateKey
|
||||
}
|
||||
|
||||
type queuedPDU struct {
|
||||
receipt *shared.Receipt
|
||||
pdu *gomatrixserverlib.HeaderedEvent
|
||||
}
|
||||
|
||||
type queuedEDU struct {
|
||||
receipt *shared.Receipt
|
||||
edu *gomatrixserverlib.EDU
|
||||
}
|
||||
|
||||
func (oqs *OutgoingQueues) getQueue(destination gomatrixserverlib.ServerName) *destinationQueue {
|
||||
if oqs.statistics.ForServer(destination).Blacklisted() {
|
||||
return nil
|
||||
}
|
||||
oqs.queuesMutex.Lock()
|
||||
defer oqs.queuesMutex.Unlock()
|
||||
oq, ok := oqs.queues[destination]
|
||||
if !ok || oq != nil {
|
||||
destinationQueueTotal.Inc()
|
||||
oq = &destinationQueue{
|
||||
queues: oqs,
|
||||
db: oqs.db,
|
||||
process: oqs.process,
|
||||
rsAPI: oqs.rsAPI,
|
||||
origin: oqs.origin,
|
||||
destination: destination,
|
||||
client: oqs.client,
|
||||
statistics: oqs.statistics.ForServer(destination),
|
||||
notify: make(chan struct{}, 1),
|
||||
interruptBackoff: make(chan bool),
|
||||
signing: oqs.signing,
|
||||
}
|
||||
oqs.queues[destination] = oq
|
||||
}
|
||||
return oq
|
||||
}
|
||||
|
||||
func (oqs *OutgoingQueues) clearQueue(oq *destinationQueue) {
|
||||
oqs.queuesMutex.Lock()
|
||||
defer oqs.queuesMutex.Unlock()
|
||||
|
||||
delete(oqs.queues, oq.destination)
|
||||
destinationQueueTotal.Dec()
|
||||
}
|
||||
|
||||
type ErrorFederationDisabled struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *ErrorFederationDisabled) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// SendEvent sends an event to the destinations
|
||||
func (oqs *OutgoingQueues) SendEvent(
|
||||
ev *gomatrixserverlib.HeaderedEvent, origin gomatrixserverlib.ServerName,
|
||||
destinations []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
if oqs.disabled {
|
||||
return &ErrorFederationDisabled{
|
||||
Message: "Federation disabled",
|
||||
}
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
)
|
||||
}
|
||||
|
||||
// Deduplicate destinations and remove the origin from the list of
|
||||
// destinations just to be sure.
|
||||
destmap := map[gomatrixserverlib.ServerName]struct{}{}
|
||||
for _, d := range destinations {
|
||||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
|
||||
// Check if any of the destinations are prohibited by server ACLs.
|
||||
for destination := range destmap {
|
||||
if api.IsServerBannedFromRoom(
|
||||
context.TODO(),
|
||||
oqs.rsAPI,
|
||||
ev.RoomID(),
|
||||
destination,
|
||||
) {
|
||||
delete(destmap, destination)
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no remaining destinations then give up.
|
||||
if len(destmap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"destinations": len(destmap), "event": ev.EventID(),
|
||||
}).Infof("Sending event")
|
||||
|
||||
headeredJSON, err := json.Marshal(ev)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
|
||||
nid, err := oqs.db.StoreJSON(context.TODO(), string(headeredJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
||||
}
|
||||
|
||||
for destination := range destmap {
|
||||
if queue := oqs.getQueue(destination); queue != nil {
|
||||
queue.sendEvent(ev, nid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendEDU sends an EDU event to the destinations.
|
||||
func (oqs *OutgoingQueues) SendEDU(
|
||||
e *gomatrixserverlib.EDU, origin gomatrixserverlib.ServerName,
|
||||
destinations []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
if oqs.disabled {
|
||||
return &ErrorFederationDisabled{
|
||||
Message: "Federation disabled",
|
||||
}
|
||||
}
|
||||
if origin != oqs.origin {
|
||||
// TODO: Support virtual hosting; gh issue #577.
|
||||
return fmt.Errorf(
|
||||
"sendevent: unexpected server to send as: got %q expected %q",
|
||||
origin, oqs.origin,
|
||||
)
|
||||
}
|
||||
|
||||
// Deduplicate destinations and remove the origin from the list of
|
||||
// destinations just to be sure.
|
||||
destmap := map[gomatrixserverlib.ServerName]struct{}{}
|
||||
for _, d := range destinations {
|
||||
destmap[d] = struct{}{}
|
||||
}
|
||||
delete(destmap, oqs.origin)
|
||||
|
||||
// There is absolutely no guarantee that the EDU will have a room_id
|
||||
// field, as it is not required by the spec. However, if it *does*
|
||||
// (e.g. typing notifications) then we should try to make sure we don't
|
||||
// bother sending them to servers that are prohibited by the server
|
||||
// ACLs.
|
||||
if result := gjson.GetBytes(e.Content, "room_id"); result.Exists() {
|
||||
for destination := range destmap {
|
||||
if api.IsServerBannedFromRoom(
|
||||
context.TODO(),
|
||||
oqs.rsAPI,
|
||||
result.Str,
|
||||
destination,
|
||||
) {
|
||||
delete(destmap, destination)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If there are no remaining destinations then give up.
|
||||
if len(destmap) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"destinations": len(destmap), "edu_type": e.Type,
|
||||
}).Info("Sending EDU event")
|
||||
|
||||
ephemeralJSON, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
return fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
|
||||
nid, err := oqs.db.StoreJSON(context.TODO(), string(ephemeralJSON))
|
||||
if err != nil {
|
||||
return fmt.Errorf("sendevent: oqs.db.StoreJSON: %w", err)
|
||||
}
|
||||
|
||||
for destination := range destmap {
|
||||
if queue := oqs.getQueue(destination); queue != nil {
|
||||
queue.sendEDU(e, nid)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RetryServer attempts to resend events to the given server if we had given up.
|
||||
func (oqs *OutgoingQueues) RetryServer(srv gomatrixserverlib.ServerName) {
|
||||
if oqs.disabled {
|
||||
return
|
||||
}
|
||||
if queue := oqs.getQueue(srv); queue != nil {
|
||||
queue.wakeQueueIfNeeded()
|
||||
}
|
||||
}
|
|
@ -21,7 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
"github.com/matrix-org/dendrite/keyserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -179,7 +179,7 @@ func localKeys(cfg *config.FederationAPI, validUntil time.Time) (*gomatrixserver
|
|||
|
||||
func NotaryKeys(
|
||||
httpReq *http.Request, cfg *config.FederationAPI,
|
||||
fsAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||
fsAPI federationAPI.FederationInternalAPI,
|
||||
req *gomatrixserverlib.PublicKeyNotaryLookupRequest,
|
||||
) util.JSONResponse {
|
||||
if req == nil {
|
||||
|
@ -203,8 +203,8 @@ func NotaryKeys(
|
|||
return util.ErrorResponse(err)
|
||||
}
|
||||
} else {
|
||||
var resp federationSenderAPI.QueryServerKeysResponse
|
||||
err := fsAPI.QueryServerKeys(httpReq.Context(), &federationSenderAPI.QueryServerKeysRequest{
|
||||
var resp federationAPI.QueryServerKeysResponse
|
||||
err := fsAPI.QueryServerKeys(httpReq.Context(), &federationAPI.QueryServerKeysRequest{
|
||||
ServerName: serverName,
|
||||
KeyIDToCriteria: kidToCriteria,
|
||||
}, &resp)
|
||||
|
|
|
@ -33,7 +33,7 @@ func Peek(
|
|||
roomID, peekID string,
|
||||
remoteVersions []gomatrixserverlib.RoomVersion,
|
||||
) util.JSONResponse {
|
||||
// TODO: check if we're just refreshing an existing peek by querying the federationsender
|
||||
// TODO: check if we're just refreshing an existing peek by querying the federationapi
|
||||
|
||||
verReq := api.QueryRoomVersionForRoomRequest{RoomID: roomID}
|
||||
verRes := api.QueryRoomVersionForRoomResponse{}
|
||||
|
|
|
@ -19,7 +19,7 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/gomatrix"
|
||||
|
@ -33,7 +33,7 @@ func RoomAliasToID(
|
|||
federation *gomatrixserverlib.FederationClient,
|
||||
cfg *config.FederationAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
senderAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||
senderAPI federationAPI.FederationInternalAPI,
|
||||
) util.JSONResponse {
|
||||
roomAlias := httpReq.FormValue("room_alias")
|
||||
if roomAlias == "" {
|
||||
|
@ -64,8 +64,8 @@ func RoomAliasToID(
|
|||
}
|
||||
|
||||
if queryRes.RoomID != "" {
|
||||
serverQueryReq := federationSenderAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID}
|
||||
var serverQueryRes federationSenderAPI.QueryJoinedHostServerNamesInRoomResponse
|
||||
serverQueryReq := federationAPI.QueryJoinedHostServerNamesInRoomRequest{RoomID: queryRes.RoomID}
|
||||
var serverQueryRes federationAPI.QueryJoinedHostServerNamesInRoomResponse
|
||||
if err = senderAPI.QueryJoinedHostServerNamesInRoom(httpReq.Context(), &serverQueryReq, &serverQueryRes); err != nil {
|
||||
util.GetLogger(httpReq.Context()).WithError(err).Error("senderAPI.QueryJoinedHostServerNamesInRoom failed")
|
||||
return jsonerror.InternalServerError()
|
||||
|
|
|
@ -21,7 +21,6 @@ import (
|
|||
"github.com/matrix-org/dendrite/clientapi/jsonerror"
|
||||
eduserverAPI "github.com/matrix-org/dendrite/eduserver/api"
|
||||
federationAPI "github.com/matrix-org/dendrite/federationapi/api"
|
||||
federationSenderAPI "github.com/matrix-org/dendrite/federationsender/api"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/httputil"
|
||||
keyserverAPI "github.com/matrix-org/dendrite/keyserver/api"
|
||||
|
@ -46,7 +45,7 @@ func Setup(
|
|||
cfg *config.FederationAPI,
|
||||
rsAPI roomserverAPI.RoomserverInternalAPI,
|
||||
eduAPI eduserverAPI.EDUServerInputAPI,
|
||||
fsAPI federationSenderAPI.FederationSenderInternalAPI,
|
||||
fsAPI federationAPI.FederationInternalAPI,
|
||||
keys gomatrixserverlib.JSONVerifier,
|
||||
federation *gomatrixserverlib.FederationClient,
|
||||
userAPI userapi.UserInternalAPI,
|
||||
|
|
180
federationapi/statistics/statistics.go
Normal file
180
federationapi/statistics/statistics.go
Normal file
|
@ -0,0 +1,180 @@
|
|||
package statistics
|
||||
|
||||
import (
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
// Statistics contains information about all of the remote federated
|
||||
// hosts that we have interacted with. It is basically a threadsafe
|
||||
// wrapper.
|
||||
type Statistics struct {
|
||||
DB storage.Database
|
||||
servers map[gomatrixserverlib.ServerName]*ServerStatistics
|
||||
mutex sync.RWMutex
|
||||
|
||||
// How many times should we tolerate consecutive failures before we
|
||||
// just blacklist the host altogether? The backoff is exponential,
|
||||
// so the max time here to attempt is 2**failures seconds.
|
||||
FailuresUntilBlacklist uint32
|
||||
}
|
||||
|
||||
// ForServer returns server statistics for the given server name. If it
|
||||
// does not exist, it will create empty statistics and return those.
|
||||
func (s *Statistics) ForServer(serverName gomatrixserverlib.ServerName) *ServerStatistics {
|
||||
// If the map hasn't been initialised yet then do that.
|
||||
if s.servers == nil {
|
||||
s.mutex.Lock()
|
||||
s.servers = make(map[gomatrixserverlib.ServerName]*ServerStatistics)
|
||||
s.mutex.Unlock()
|
||||
}
|
||||
// Look up if we have statistics for this server already.
|
||||
s.mutex.RLock()
|
||||
server, found := s.servers[serverName]
|
||||
s.mutex.RUnlock()
|
||||
// If we don't, then make one.
|
||||
if !found {
|
||||
s.mutex.Lock()
|
||||
server = &ServerStatistics{
|
||||
statistics: s,
|
||||
serverName: serverName,
|
||||
interrupt: make(chan struct{}),
|
||||
}
|
||||
s.servers[serverName] = server
|
||||
s.mutex.Unlock()
|
||||
blacklisted, err := s.DB.IsServerBlacklisted(serverName)
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to get blacklist entry %q", serverName)
|
||||
} else {
|
||||
server.blacklisted.Store(blacklisted)
|
||||
}
|
||||
}
|
||||
return server
|
||||
}
|
||||
|
||||
// ServerStatistics contains information about our interactions with a
|
||||
// remote federated host, e.g. how many times we were successful, how
|
||||
// many times we failed etc. It also manages the backoff time and black-
|
||||
// listing a remote host if it remains uncooperative.
|
||||
type ServerStatistics struct {
|
||||
statistics *Statistics //
|
||||
serverName gomatrixserverlib.ServerName //
|
||||
blacklisted atomic.Bool // is the node blacklisted
|
||||
backoffStarted atomic.Bool // is the backoff started
|
||||
backoffUntil atomic.Value // time.Time until this backoff interval ends
|
||||
backoffCount atomic.Uint32 // number of times BackoffDuration has been called
|
||||
interrupt chan struct{} // interrupts the backoff goroutine
|
||||
successCounter atomic.Uint32 // how many times have we succeeded?
|
||||
}
|
||||
|
||||
// duration returns how long the next backoff interval should be.
|
||||
func (s *ServerStatistics) duration(count uint32) time.Duration {
|
||||
return time.Second * time.Duration(math.Exp2(float64(count)))
|
||||
}
|
||||
|
||||
// cancel will interrupt the currently active backoff.
|
||||
func (s *ServerStatistics) cancel() {
|
||||
s.blacklisted.Store(false)
|
||||
s.backoffUntil.Store(time.Time{})
|
||||
select {
|
||||
case s.interrupt <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Success updates the server statistics with a new successful
|
||||
// attempt, which increases the sent counter and resets the idle and
|
||||
// failure counters. If a host was blacklisted at this point then
|
||||
// we will unblacklist it.
|
||||
func (s *ServerStatistics) Success() {
|
||||
s.cancel()
|
||||
s.successCounter.Inc()
|
||||
s.backoffCount.Store(0)
|
||||
if s.statistics.DB != nil {
|
||||
if err := s.statistics.DB.RemoveServerFromBlacklist(s.serverName); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to remove %q from blacklist", s.serverName)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Failure marks a failure and starts backing off if needed.
|
||||
// The next call to BackoffIfRequired will do the right thing
|
||||
// after this. It will return the time that the current failure
|
||||
// will result in backoff waiting until, and a bool signalling
|
||||
// whether we have blacklisted and therefore to give up.
|
||||
func (s *ServerStatistics) Failure() (time.Time, bool) {
|
||||
// If we aren't already backing off, this call will start
|
||||
// a new backoff period. Increase the failure counter and
|
||||
// start a goroutine which will wait out the backoff and
|
||||
// unset the backoffStarted flag when done.
|
||||
if s.backoffStarted.CAS(false, true) {
|
||||
if s.backoffCount.Inc() >= s.statistics.FailuresUntilBlacklist {
|
||||
s.blacklisted.Store(true)
|
||||
if s.statistics.DB != nil {
|
||||
if err := s.statistics.DB.AddServerToBlacklist(s.serverName); err != nil {
|
||||
logrus.WithError(err).Errorf("Failed to add %q to blacklist", s.serverName)
|
||||
}
|
||||
}
|
||||
return time.Time{}, true
|
||||
}
|
||||
|
||||
go func() {
|
||||
until, ok := s.backoffUntil.Load().(time.Time)
|
||||
if ok {
|
||||
select {
|
||||
case <-time.After(time.Until(until)):
|
||||
case <-s.interrupt:
|
||||
}
|
||||
}
|
||||
s.backoffStarted.Store(false)
|
||||
}()
|
||||
}
|
||||
|
||||
// Check if we have blacklisted this node.
|
||||
if s.blacklisted.Load() {
|
||||
return time.Now(), true
|
||||
}
|
||||
|
||||
// If we're already backing off and we haven't yet surpassed
|
||||
// the deadline then return that. Repeated calls to Failure
|
||||
// within a single backoff interval will have no side effects.
|
||||
if until, ok := s.backoffUntil.Load().(time.Time); ok && !time.Now().After(until) {
|
||||
return until, false
|
||||
}
|
||||
|
||||
// We're either backing off and have passed the deadline, or
|
||||
// we aren't backing off, so work out what the next interval
|
||||
// will be.
|
||||
count := s.backoffCount.Load()
|
||||
until := time.Now().Add(s.duration(count))
|
||||
s.backoffUntil.Store(until)
|
||||
return until, false
|
||||
}
|
||||
|
||||
// BackoffInfo returns information about the current or previous backoff.
|
||||
// Returns the last backoffUntil time and whether the server is currently blacklisted or not.
|
||||
func (s *ServerStatistics) BackoffInfo() (*time.Time, bool) {
|
||||
until, ok := s.backoffUntil.Load().(time.Time)
|
||||
if ok {
|
||||
return &until, s.blacklisted.Load()
|
||||
}
|
||||
return nil, s.blacklisted.Load()
|
||||
}
|
||||
|
||||
// Blacklisted returns true if the server is blacklisted and false
|
||||
// otherwise.
|
||||
func (s *ServerStatistics) Blacklisted() bool {
|
||||
return s.blacklisted.Load()
|
||||
}
|
||||
|
||||
// SuccessCount returns the number of successful requests. This is
|
||||
// usually useful in constructing transaction IDs.
|
||||
func (s *ServerStatistics) SuccessCount() uint32 {
|
||||
return s.successCounter.Load()
|
||||
}
|
64
federationapi/statistics/statistics_test.go
Normal file
64
federationapi/statistics/statistics_test.go
Normal file
|
@ -0,0 +1,64 @@
|
|||
package statistics
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestBackoff(t *testing.T) {
|
||||
stats := Statistics{
|
||||
FailuresUntilBlacklist: 7,
|
||||
}
|
||||
server := ServerStatistics{
|
||||
statistics: &stats,
|
||||
serverName: "test.com",
|
||||
}
|
||||
|
||||
// Start by checking that counting successes works.
|
||||
server.Success()
|
||||
if successes := server.SuccessCount(); successes != 1 {
|
||||
t.Fatalf("Expected success count 1, got %d", successes)
|
||||
}
|
||||
|
||||
// Register a failure.
|
||||
server.Failure()
|
||||
|
||||
t.Logf("Backoff counter: %d", server.backoffCount.Load())
|
||||
|
||||
// Now we're going to simulate backing off a few times to see
|
||||
// what happens.
|
||||
for i := uint32(1); i <= 10; i++ {
|
||||
// Register another failure for good measure. This should have no
|
||||
// side effects since a backoff is already in progress. If it does
|
||||
// then we'll fail.
|
||||
until, blacklisted := server.Failure()
|
||||
|
||||
// Get the duration.
|
||||
_, blacklist := server.BackoffInfo()
|
||||
duration := time.Until(until).Round(time.Second)
|
||||
|
||||
// Unset the backoff, or otherwise our next call will think that
|
||||
// there's a backoff in progress and return the same result.
|
||||
server.cancel()
|
||||
server.backoffStarted.Store(false)
|
||||
|
||||
// Check if we should be blacklisted by now.
|
||||
if i >= stats.FailuresUntilBlacklist {
|
||||
if !blacklist {
|
||||
t.Fatalf("Backoff %d should have resulted in blacklist but didn't", i)
|
||||
} else if blacklist != blacklisted {
|
||||
t.Fatalf("BackoffInfo and Failure returned different blacklist values")
|
||||
} else {
|
||||
t.Logf("Backoff %d is blacklisted as expected", i)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Check if the duration is what we expect.
|
||||
t.Logf("Backoff %d is for %s", i, duration)
|
||||
if wanted := time.Second * time.Duration(math.Exp2(float64(i))); !blacklist && duration != wanted {
|
||||
t.Fatalf("Backoff %d should have been %s but was %s", i, wanted, duration)
|
||||
}
|
||||
}
|
||||
}
|
68
federationapi/storage/cache/keydb.go
vendored
Normal file
68
federationapi/storage/cache/keydb.go
vendored
Normal file
|
@ -0,0 +1,68 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A Database implements gomatrixserverlib.KeyDatabase and is used to store
|
||||
// the public keys for other matrix servers.
|
||||
type KeyDatabase struct {
|
||||
inner gomatrixserverlib.KeyDatabase
|
||||
cache caching.ServerKeyCache
|
||||
}
|
||||
|
||||
func NewKeyDatabase(inner gomatrixserverlib.KeyDatabase, cache caching.ServerKeyCache) (*KeyDatabase, error) {
|
||||
if inner == nil {
|
||||
return nil, errors.New("inner database can't be nil")
|
||||
}
|
||||
if cache == nil {
|
||||
return nil, errors.New("cache can't be nil")
|
||||
}
|
||||
return &KeyDatabase{
|
||||
inner: inner,
|
||||
cache: cache,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// FetcherName implements KeyFetcher
|
||||
func (d KeyDatabase) FetcherName() string {
|
||||
return "InMemoryKeyCache"
|
||||
}
|
||||
|
||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *KeyDatabase) FetchKeys(
|
||||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult)
|
||||
for req, ts := range requests {
|
||||
if res, cached := d.cache.GetServerKey(req, ts); cached {
|
||||
results[req] = res
|
||||
delete(requests, req)
|
||||
}
|
||||
}
|
||||
fromDB, err := d.inner.FetchKeys(ctx, requests)
|
||||
if err != nil {
|
||||
return results, err
|
||||
}
|
||||
for req, res := range fromDB {
|
||||
results[req] = res
|
||||
d.cache.StoreServerKey(req, res)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *KeyDatabase) StoreKeys(
|
||||
ctx context.Context,
|
||||
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
for req, res := range keyMap {
|
||||
d.cache.StoreServerKey(req, res)
|
||||
}
|
||||
return d.inner.StoreKeys(ctx, keyMap)
|
||||
}
|
76
federationapi/storage/interface.go
Normal file
76
federationapi/storage/interface.go
Normal file
|
@ -0,0 +1,76 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
internal.PartitionStorer
|
||||
gomatrixserverlib.KeyDatabase
|
||||
|
||||
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
|
||||
|
||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
||||
PurgeRoomState(ctx context.Context, roomID string) error
|
||||
|
||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||
|
||||
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
|
||||
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
|
||||
|
||||
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
|
||||
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
|
||||
|
||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
|
||||
GetPendingPDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
GetPendingEDUCount(ctx context.Context, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
|
||||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
// these don't have contexts passed in as we want things to happen regardless of the request context
|
||||
AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error
|
||||
RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error
|
||||
RemoveAllServersFromBlacklist() error
|
||||
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
|
||||
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
|
||||
GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error)
|
||||
|
||||
AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error)
|
||||
GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error)
|
||||
|
||||
// Update the notary with the given server keys from the given server name.
|
||||
UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error
|
||||
// Query the notary for the server keys for the given server. If `optKeyIDs` is not empty, multiple server keys may be returned (between 1 - len(optKeyIDs))
|
||||
// such that the combination of all server keys will include all the `optKeyIDs`.
|
||||
GetNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
|
||||
}
|
115
federationapi/storage/postgres/blacklist_table.go
Normal file
115
federationapi/storage/postgres/blacklist_table.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const blacklistSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
|
||||
-- The blacklisted server name
|
||||
server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name)
|
||||
);
|
||||
`
|
||||
|
||||
const insertBlacklistSQL = "" +
|
||||
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectBlacklistSQL = "" +
|
||||
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
|
||||
|
||||
const deleteBlacklistSQL = "" +
|
||||
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
|
||||
|
||||
const deleteAllBlacklistSQL = "" +
|
||||
"TRUNCATE federationsender_blacklist"
|
||||
|
||||
type blacklistStatements struct {
|
||||
db *sql.DB
|
||||
insertBlacklistStmt *sql.Stmt
|
||||
selectBlacklistStmt *sql.Stmt
|
||||
deleteBlacklistStmt *sql.Stmt
|
||||
deleteAllBlacklistStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||
s = &blacklistStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(blacklistSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) InsertBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) SelectBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is blacklisted, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) DeleteBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) DeleteAllBlacklist(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
|
||||
}
|
||||
|
||||
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
|
||||
}
|
||||
|
||||
func UpRemoveRoomsTable(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(`
|
||||
DROP TABLE IF EXISTS federationsender_rooms;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRemoveRoomsTable(tx *sql.Tx) error {
|
||||
// We can't reverse this.
|
||||
return nil
|
||||
}
|
176
federationapi/storage/postgres/inbound_peeks_table.go
Normal file
176
federationapi/storage/postgres/inbound_peeks_table.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const inboundPeeksSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
|
||||
room_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
peek_id TEXT NOT NULL,
|
||||
creation_ts BIGINT NOT NULL,
|
||||
renewed_ts BIGINT NOT NULL,
|
||||
renewal_interval BIGINT NOT NULL,
|
||||
UNIQUE (room_id, server_name, peek_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertInboundPeekSQL = "" +
|
||||
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const selectInboundPeekSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectInboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
||||
const renewInboundPeekSQL = "" +
|
||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteInboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
|
||||
const deleteInboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
||||
type inboundPeeksStatements struct {
|
||||
db *sql.DB
|
||||
insertInboundPeekStmt *sql.Stmt
|
||||
selectInboundPeekStmt *sql.Stmt
|
||||
selectInboundPeeksStmt *sql.Stmt
|
||||
renewInboundPeekStmt *sql.Stmt
|
||||
deleteInboundPeekStmt *sql.Stmt
|
||||
deleteInboundPeeksStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
|
||||
s = &inboundPeeksStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(inboundPeeksSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) RenewInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) SelectInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (*types.InboundPeek, error) {
|
||||
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
inboundPeek := types.InboundPeek{}
|
||||
err := row.Scan(
|
||||
&inboundPeek.RoomID,
|
||||
&inboundPeek.ServerName,
|
||||
&inboundPeek.PeekID,
|
||||
&inboundPeek.CreationTimestamp,
|
||||
&inboundPeek.RenewedTimestamp,
|
||||
&inboundPeek.RenewalInterval,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &inboundPeek, nil
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) SelectInboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (inboundPeeks []types.InboundPeek, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
inboundPeek := types.InboundPeek{}
|
||||
if err = rows.Scan(
|
||||
&inboundPeek.RoomID,
|
||||
&inboundPeek.ServerName,
|
||||
&inboundPeek.PeekID,
|
||||
&inboundPeek.CreationTimestamp,
|
||||
&inboundPeek.RenewedTimestamp,
|
||||
&inboundPeek.RenewalInterval,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
inboundPeeks = append(inboundPeeks, inboundPeek)
|
||||
}
|
||||
|
||||
return inboundPeeks, rows.Err()
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) DeleteInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) DeleteInboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
|
||||
return
|
||||
}
|
212
federationapi/storage/postgres/joined_hosts_table.go
Normal file
212
federationapi/storage/postgres/joined_hosts_table.go
Normal file
|
@ -0,0 +1,212 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const joinedHostsSchema = `
|
||||
-- The joined_hosts table stores a list of m.room.member event ids in the
|
||||
-- current state for each room where the membership is "join".
|
||||
-- There will be an entry for every user that is joined to the room.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
|
||||
-- The string ID of the room.
|
||||
room_id TEXT NOT NULL,
|
||||
-- The event ID of the m.room.member join event.
|
||||
event_id TEXT NOT NULL,
|
||||
-- The domain part of the user ID the m.room.member event is for.
|
||||
server_name TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
|
||||
ON federationsender_joined_hosts (event_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
|
||||
ON federationsender_joined_hosts (room_id)
|
||||
`
|
||||
|
||||
const insertJoinedHostsSQL = "" +
|
||||
"INSERT INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
|
||||
" VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
|
||||
|
||||
const deleteJoinedHostsSQL = "" +
|
||||
"DELETE FROM federationsender_joined_hosts WHERE event_id = ANY($1)"
|
||||
|
||||
const deleteJoinedHostsForRoomSQL = "" +
|
||||
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
|
||||
|
||||
const selectJoinedHostsSQL = "" +
|
||||
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
||||
" WHERE room_id = $1"
|
||||
|
||||
const selectAllJoinedHostsSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
||||
|
||||
const selectJoinedHostsForRoomsSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
|
||||
|
||||
type joinedHostsStatements struct {
|
||||
db *sql.DB
|
||||
insertJoinedHostsStmt *sql.Stmt
|
||||
deleteJoinedHostsStmt *sql.Stmt
|
||||
deleteJoinedHostsForRoomStmt *sql.Stmt
|
||||
selectJoinedHostsStmt *sql.Stmt
|
||||
selectAllJoinedHostsStmt *sql.Stmt
|
||||
selectJoinedHostsForRoomsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||
s = &joinedHostsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(joinedHostsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJoinedHostsStmt, err = s.db.Prepare(insertJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJoinedHostsStmt, err = s.db.Prepare(deleteJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectJoinedHostsStmt, err = s.db.Prepare(selectJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAllJoinedHostsStmt, err = s.db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomID, eventID string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.StringArray(eventIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
|
||||
return joinedHostsFromStmt(ctx, stmt, roomID)
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHosts(
|
||||
ctx context.Context, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||
ctx context.Context,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||
ctx context.Context, roomIDs []string,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func joinedHostsFromStmt(
|
||||
ctx context.Context, stmt *sql.Stmt, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
|
||||
|
||||
var result []types.JoinedHost
|
||||
for rows.Next() {
|
||||
var eventID, serverName string
|
||||
if err = rows.Scan(&eventID, &serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, types.JoinedHost{
|
||||
MemberEventID: eventID,
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
})
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
|
@ -0,0 +1,64 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const notaryServerKeysJSONSchema = `
|
||||
CREATE SEQUENCE IF NOT EXISTS federationsender_notary_server_keys_json_pkey;
|
||||
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_json (
|
||||
notary_id BIGINT PRIMARY KEY NOT NULL DEFAULT nextval('federationsender_notary_server_keys_json_pkey'),
|
||||
response_json TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
valid_until BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertServerKeysJSONSQL = "" +
|
||||
"INSERT INTO federationsender_notary_server_keys_json (response_json, server_name, valid_until) VALUES ($1, $2, $3)" +
|
||||
" RETURNING notary_id"
|
||||
|
||||
type notaryServerKeysStatements struct {
|
||||
db *sql.DB
|
||||
insertServerKeysJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, err error) {
|
||||
s = ¬aryServerKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(notaryServerKeysJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysStatements) InsertJSONResponse(
|
||||
ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp,
|
||||
) (tables.NotaryID, error) {
|
||||
var notaryID tables.NotaryID
|
||||
return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID)
|
||||
}
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const notaryServerKeysMetadataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_metadata (
|
||||
notary_id BIGINT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
UNIQUE (server_name, key_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertServerKeysSQL = "" +
|
||||
"INSERT INTO federationsender_notary_server_keys_metadata (notary_id, server_name, key_id) VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (server_name, key_id) DO UPDATE SET notary_id = $1"
|
||||
|
||||
// for a given (server_name, key_id), find the existing notary ID and valid until. Used to check if we will replace it
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyMetadataSQL = `
|
||||
SELECT federationsender_notary_server_keys_metadata.notary_id, valid_until FROM federationsender_notary_server_keys_json
|
||||
JOIN federationsender_notary_server_keys_metadata ON
|
||||
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
|
||||
WHERE federationsender_notary_server_keys_metadata.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = $2
|
||||
`
|
||||
|
||||
// select the response which has the highest valid_until value
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyResponsesSQL = `
|
||||
SELECT response_json FROM federationsender_notary_server_keys_json
|
||||
WHERE server_name = $1 AND valid_until = (
|
||||
SELECT MAX(valid_until) FROM federationsender_notary_server_keys_json WHERE server_name = $1
|
||||
)
|
||||
`
|
||||
|
||||
// select the responses which have the given key IDs
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyResponsesWithKeyIDsSQL = `
|
||||
SELECT response_json FROM federationsender_notary_server_keys_json
|
||||
JOIN federationsender_notary_server_keys_metadata ON
|
||||
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
|
||||
WHERE federationsender_notary_server_keys_json.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = ANY ($2)
|
||||
GROUP BY federationsender_notary_server_keys_json.notary_id
|
||||
`
|
||||
|
||||
// JOINs with the metadata table
|
||||
const deleteUnusedServerKeysJSONSQL = `
|
||||
DELETE FROM federationsender_notary_server_keys_json WHERE federationsender_notary_server_keys_json.notary_id NOT IN (
|
||||
SELECT DISTINCT notary_id FROM federationsender_notary_server_keys_metadata
|
||||
)
|
||||
`
|
||||
|
||||
type notaryServerKeysMetadataStatements struct {
|
||||
db *sql.DB
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
selectNotaryKeyResponsesStmt *sql.Stmt
|
||||
selectNotaryKeyResponsesWithKeyIDsStmt *sql.Stmt
|
||||
selectNotaryKeyMetadataStmt *sql.Stmt
|
||||
deleteUnusedServerKeysJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMetadataStatements, err error) {
|
||||
s = ¬aryServerKeysMetadataStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(notaryServerKeysMetadataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNotaryKeyResponsesWithKeyIDsStmt, err = db.Prepare(selectNotaryKeyResponsesWithKeyIDsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) UpsertKey(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp,
|
||||
) (tables.NotaryID, error) {
|
||||
notaryID := newNotaryID
|
||||
// see if the existing notary ID a) exists, b) has a longer valid_until
|
||||
var existingNotaryID tables.NotaryID
|
||||
var existingValidUntil gomatrixserverlib.Timestamp
|
||||
if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if existingValidUntil.Time().After(newValidUntil.Time()) {
|
||||
// the existing valid_until is valid longer, so use that.
|
||||
return existingNotaryID, nil
|
||||
}
|
||||
// overwrite the notary_id for this (server_name, key_id) tuple
|
||||
_, err := txn.Stmt(s.upsertServerKeysStmt).ExecContext(ctx, notaryID, serverName, keyID)
|
||||
return notaryID, err
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if len(keyIDs) == 0 {
|
||||
rows, err = txn.Stmt(s.selectNotaryKeyResponsesStmt).QueryContext(ctx, string(serverName))
|
||||
} else {
|
||||
keyIDstr := make([]string, len(keyIDs))
|
||||
for i := range keyIDs {
|
||||
keyIDstr[i] = string(keyIDs[i])
|
||||
}
|
||||
rows, err = txn.Stmt(s.selectNotaryKeyResponsesWithKeyIDsStmt).QueryContext(ctx, string(serverName), pq.StringArray(keyIDstr))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectNotaryKeyResponsesStmt close failed")
|
||||
var results []gomatrixserverlib.ServerKeys
|
||||
for rows.Next() {
|
||||
var sk gomatrixserverlib.ServerKeys
|
||||
var raw string
|
||||
if err = rows.Scan(&raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(raw), &sk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, sk)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.Stmt(s.deleteUnusedServerKeysJSONStmt).ExecContext(ctx)
|
||||
return err
|
||||
}
|
176
federationapi/storage/postgres/outbound_peeks_table.go
Normal file
176
federationapi/storage/postgres/outbound_peeks_table.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const outboundPeeksSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
|
||||
room_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
peek_id TEXT NOT NULL,
|
||||
creation_ts BIGINT NOT NULL,
|
||||
renewed_ts BIGINT NOT NULL,
|
||||
renewal_interval BIGINT NOT NULL,
|
||||
UNIQUE (room_id, server_name, peek_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertOutboundPeekSQL = "" +
|
||||
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const selectOutboundPeekSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectOutboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
||||
const renewOutboundPeekSQL = "" +
|
||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteOutboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
|
||||
const deleteOutboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
||||
type outboundPeeksStatements struct {
|
||||
db *sql.DB
|
||||
insertOutboundPeekStmt *sql.Stmt
|
||||
selectOutboundPeekStmt *sql.Stmt
|
||||
selectOutboundPeeksStmt *sql.Stmt
|
||||
renewOutboundPeekStmt *sql.Stmt
|
||||
deleteOutboundPeekStmt *sql.Stmt
|
||||
deleteOutboundPeeksStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
|
||||
s = &outboundPeeksStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(outboundPeeksSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) RenewOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) SelectOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (*types.OutboundPeek, error) {
|
||||
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
outboundPeek := types.OutboundPeek{}
|
||||
err := row.Scan(
|
||||
&outboundPeek.RoomID,
|
||||
&outboundPeek.ServerName,
|
||||
&outboundPeek.PeekID,
|
||||
&outboundPeek.CreationTimestamp,
|
||||
&outboundPeek.RenewedTimestamp,
|
||||
&outboundPeek.RenewalInterval,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &outboundPeek, nil
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) SelectOutboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (outboundPeeks []types.OutboundPeek, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
outboundPeek := types.OutboundPeek{}
|
||||
if err = rows.Scan(
|
||||
&outboundPeek.RoomID,
|
||||
&outboundPeek.ServerName,
|
||||
&outboundPeek.PeekID,
|
||||
&outboundPeek.CreationTimestamp,
|
||||
&outboundPeek.RenewedTimestamp,
|
||||
&outboundPeek.RenewalInterval,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
outboundPeeks = append(outboundPeeks, outboundPeek)
|
||||
}
|
||||
|
||||
return outboundPeeks, rows.Err()
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) DeleteOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) DeleteOutboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
|
||||
return
|
||||
}
|
198
federationapi/storage/postgres/queue_edus_table.go
Normal file
198
federationapi/storage/postgres/queue_edus_table.go
Normal file
|
@ -0,0 +1,198 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queueEDUsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
|
||||
-- The type of the event (informational).
|
||||
edu_type TEXT NOT NULL,
|
||||
-- The domain part of the user ID the EDU event is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_queue_edus_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
|
||||
ON federationsender_queue_edus (json_nid, server_name);
|
||||
`
|
||||
|
||||
const insertQueueEDUSQL = "" +
|
||||
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueueEDUSQL = "" +
|
||||
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid = ANY($2)"
|
||||
|
||||
const selectQueueEDUSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueueEDUCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueueServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||
|
||||
type queueEDUsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueEDUStmt *sql.Stmt
|
||||
deleteQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueueEDUCountStmt *sql.Stmt
|
||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
||||
s = &queueEDUsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(queueEDUsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueueEDUStmt, err = s.db.Prepare(insertQueueEDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteQueueEDUStmt, err = s.db.Prepare(deleteQueueEDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUStmt, err = s.db.Prepare(selectQueueEDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUCountStmt, err = s.db.Prepare(selectQueueEDUCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUServerNamesStmt, err = s.db.Prepare(selectQueueServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) InsertQueueEDU(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
eduType string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eduType, // the EDU type
|
||||
serverName, // destination server name
|
||||
nid, // JSON blob NID
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) DeleteQueueEDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteQueueEDUStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []int64
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
if err = rows.Scan(&nid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, nid)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||
ctx context.Context, txn *sql.Tx, jsonNID int64,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
return -1, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, serverName)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
115
federationapi/storage/postgres/queue_json_table.go
Normal file
115
federationapi/storage/postgres/queue_json_table.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const queueJSONSchema = `
|
||||
-- The federationsender_queue_json table contains event contents that
|
||||
-- we failed to send.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
|
||||
-- The JSON NID. This allows the federationsender_queue_retry table to
|
||||
-- cross-reference to find the JSON blob.
|
||||
json_nid BIGSERIAL,
|
||||
-- The JSON body. Text so that we preserve UTF-8.
|
||||
json_body TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertJSONSQL = "" +
|
||||
"INSERT INTO federationsender_queue_json (json_body)" +
|
||||
" VALUES ($1)" +
|
||||
" RETURNING json_nid"
|
||||
|
||||
const deleteJSONSQL = "" +
|
||||
"DELETE FROM federationsender_queue_json WHERE json_nid = ANY($1)"
|
||||
|
||||
const selectJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM federationsender_queue_json" +
|
||||
" WHERE json_nid = ANY($1)"
|
||||
|
||||
type queueJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertJSONStmt *sql.Stmt
|
||||
deleteJSONStmt *sql.Stmt
|
||||
selectJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||
s = &queueJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(queueJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJSONStmt, err = s.db.Prepare(insertJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJSONStmt, err = s.db.Prepare(deleteJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectJSONStmt, err = s.db.Prepare(selectJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) InsertQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, json string,
|
||||
) (int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||
var lastid int64
|
||||
if err := stmt.QueryRowContext(ctx, json).Scan(&lastid); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return lastid, nil
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) DeleteQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, nids []int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJSONStmt)
|
||||
_, err := stmt.ExecContext(ctx, pq.Int64Array(nids))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) SelectQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
|
||||
) (map[int64][]byte, error) {
|
||||
blobs := map[int64][]byte{}
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJSONStmt)
|
||||
rows, err := stmt.QueryContext(ctx, pq.Int64Array(jsonNIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
var blob []byte
|
||||
if err = rows.Scan(&nid, &blob); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blobs[nid] = blob
|
||||
}
|
||||
return blobs, err
|
||||
}
|
202
federationapi/storage/postgres/queue_pdus_table.go
Normal file
202
federationapi/storage/postgres/queue_pdus_table.go
Normal file
|
@ -0,0 +1,202 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queuePDUsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
|
||||
-- The transaction ID that was generated before persisting the event.
|
||||
transaction_id TEXT NOT NULL,
|
||||
-- The destination server that we will send the event to.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_queue_pdus_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
|
||||
ON federationsender_queue_pdus (json_nid, server_name);
|
||||
`
|
||||
|
||||
const insertQueuePDUSQL = "" +
|
||||
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueuePDUSQL = "" +
|
||||
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid = ANY($2)"
|
||||
|
||||
const selectQueuePDUsSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueuePDUReferenceJSONCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueuePDUsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueuePDUServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||
|
||||
type queuePDUsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueuePDUStmt *sql.Stmt
|
||||
deleteQueuePDUsStmt *sql.Stmt
|
||||
selectQueuePDUsStmt *sql.Stmt
|
||||
selectQueuePDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueuePDUsCountStmt *sql.Stmt
|
||||
selectQueuePDUServerNamesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||
s = &queuePDUsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = s.db.Exec(queuePDUsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueuePDUStmt, err = s.db.Prepare(insertQueuePDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteQueuePDUsStmt, err = s.db.Prepare(deleteQueuePDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsStmt, err = s.db.Prepare(selectQueuePDUsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUReferenceJSONCountStmt, err = s.db.Prepare(selectQueuePDUReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsCountStmt, err = s.db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUServerNamesStmt, err = s.db.Prepare(selectQueuePDUServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) InsertQueuePDU(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
transactionID, // the transaction ID that we initially attempted
|
||||
serverName, // destination server name
|
||||
nid, // JSON blob NID
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) DeleteQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteQueuePDUsStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName, pq.Int64Array(jsonNIDs))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||
ctx context.Context, txn *sql.Tx, jsonNID int64,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUReferenceJSONCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []int64
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
if err = rows.Scan(&nid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, nid)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUServerNamesStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, serverName)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
146
federationapi/storage/postgres/server_key_table.go
Normal file
146
federationapi/storage/postgres/server_key_table.go
Normal file
|
@ -0,0 +1,146 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const serverSigningKeysSchema = `
|
||||
-- A cache of signing keys downloaded from remote servers.
|
||||
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
||||
-- The name of the matrix server the key is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The ID of the server key.
|
||||
server_key_id TEXT NOT NULL,
|
||||
-- Combined server name and key ID separated by the ASCII unit separator
|
||||
-- to make it easier to run bulk queries.
|
||||
server_name_and_key_id TEXT NOT NULL,
|
||||
-- When the key is valid until as a millisecond timestamp.
|
||||
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
|
||||
valid_until_ts BIGINT NOT NULL,
|
||||
-- When the key expired as a millisecond timestamp.
|
||||
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
|
||||
expired_ts BIGINT NOT NULL,
|
||||
-- The base64-encoded public key.
|
||||
server_key TEXT NOT NULL,
|
||||
CONSTRAINT keydb_server_keys_unique UNIQUE (server_name, server_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
|
||||
`
|
||||
|
||||
const bulkSelectServerSigningKeysSQL = "" +
|
||||
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
|
||||
" server_key FROM keydb_server_keys" +
|
||||
" WHERE server_name_and_key_id = ANY($1)"
|
||||
|
||||
const upsertServerSigningKeysSQL = "" +
|
||||
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
||||
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT ON CONSTRAINT keydb_server_keys_unique" +
|
||||
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
|
||||
|
||||
type serverSigningKeyStatements struct {
|
||||
bulkSelectServerKeysStmt *sql.Stmt
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, err error) {
|
||||
s = &serverSigningKeyStatements{}
|
||||
_, err = db.Exec(serverSigningKeysSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *serverSigningKeyStatements) BulkSelectServerKeys(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
var nameAndKeyIDs []string
|
||||
for request := range requests {
|
||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||
}
|
||||
stmt := s.bulkSelectServerKeysStmt
|
||||
rows, err := stmt.QueryContext(ctx, pq.StringArray(nameAndKeyIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
|
||||
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
var keyID string
|
||||
var key string
|
||||
var validUntilTS int64
|
||||
var expiredTS int64
|
||||
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||
}
|
||||
vk := gomatrixserverlib.VerifyKey{}
|
||||
err = vk.Key.Decode(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: vk,
|
||||
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||
}
|
||||
}
|
||||
return results, rows.Err()
|
||||
}
|
||||
|
||||
func (s *serverSigningKeyStatements) UpsertServerKeys(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||
key gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
string(request.ServerName),
|
||||
string(request.KeyID),
|
||||
nameAndKeyID(request),
|
||||
key.ValidUntilTS,
|
||||
key.ExpiredTS,
|
||||
key.Key.Encode(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
||||
return string(request.ServerName) + "\x1F" + string(request.KeyID)
|
||||
}
|
109
federationapi/storage/postgres/storage.go
Normal file
109
federationapi/storage/postgres/storage.go
Normal file
|
@ -0,0 +1,109 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres/deltas"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// Database stores information needed by the federation sender
|
||||
type Database struct {
|
||||
shared.Database
|
||||
sqlutil.PartitionOffsetStatements
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.writer = sqlutil.NewDummyWriter()
|
||||
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queuePDUs, err := NewPostgresQueuePDUsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueEDUs, err := NewPostgresQueueEDUsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueJSON, err := NewPostgresQueueJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blacklist, err := NewPostgresBlacklistTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outboundPeeks, err := NewPostgresOutboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
notaryJSON, err := NewPostgresNotaryServerKeysTable(d.db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresNotaryServerKeysTable: %s", err)
|
||||
}
|
||||
notaryMetadata, err := NewPostgresNotaryServerKeysMetadataTable(d.db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewPostgresNotaryServerKeysMetadataTable: %s", err)
|
||||
}
|
||||
serverSigningKeys, err := NewPostgresServerSigningKeysTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadRemoveRoomsTable(m)
|
||||
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
NotaryServerKeysJSON: notaryJSON,
|
||||
NotaryServerKeysMetadata: notaryMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
}
|
||||
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
247
federationapi/storage/shared/storage.go
Normal file
247
federationapi/storage/shared/storage.go
Normal file
|
@ -0,0 +1,247 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
DB *sql.DB
|
||||
Cache caching.FederationCache
|
||||
Writer sqlutil.Writer
|
||||
FederationQueuePDUs tables.FederationQueuePDUs
|
||||
FederationQueueEDUs tables.FederationQueueEDUs
|
||||
FederationQueueJSON tables.FederationQueueJSON
|
||||
FederationJoinedHosts tables.FederationJoinedHosts
|
||||
FederationBlacklist tables.FederationBlacklist
|
||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||
FederationInboundPeeks tables.FederationInboundPeeks
|
||||
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
||||
NotaryServerKeysMetadata tables.FederationNotaryServerKeysMetadata
|
||||
ServerSigningKeys tables.FederationServerSigningKeys
|
||||
}
|
||||
|
||||
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
|
||||
// We don't actually export the NIDs but we need the caller to be able
|
||||
// to pass them back so that we can clean up if the transaction sends
|
||||
// successfully.
|
||||
type Receipt struct {
|
||||
nid int64
|
||||
}
|
||||
|
||||
func (r *Receipt) String() string {
|
||||
return fmt.Sprintf("%d", r.nid)
|
||||
}
|
||||
|
||||
// UpdateRoom updates the joined hosts for a room and returns what the joined
|
||||
// hosts were before the update, or nil if this was a duplicate message.
|
||||
// This is called when we receive a message from kafka, so we pass in
|
||||
// oldEventID and newEventID to check that we haven't missed any messages or
|
||||
// this isn't a duplicate message.
|
||||
func (d *Database) UpdateRoom(
|
||||
ctx context.Context,
|
||||
roomID, oldEventID, newEventID string,
|
||||
addHosts []types.JoinedHost,
|
||||
removeHosts []string,
|
||||
) (joinedHosts []types.JoinedHost, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
joinedHosts, err = d.FederationJoinedHosts.SelectJoinedHostsWithTx(ctx, txn, roomID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, add := range addHosts {
|
||||
err = d.FederationJoinedHosts.InsertJoinedHosts(ctx, txn, roomID, add.MemberEventID, add.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err = d.FederationJoinedHosts.DeleteJoinedHosts(ctx, txn, removeHosts); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// GetJoinedHosts returns the currently joined hosts for room,
|
||||
// as known to federationserver.
|
||||
// Returns an error if something goes wrong.
|
||||
func (d *Database) GetJoinedHosts(
|
||||
ctx context.Context, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
return d.FederationJoinedHosts.SelectJoinedHosts(ctx, roomID)
|
||||
}
|
||||
|
||||
// GetAllJoinedHosts returns the currently joined hosts for
|
||||
// all rooms known to the federation sender.
|
||||
// Returns an error if something goes wrong.
|
||||
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
||||
}
|
||||
|
||||
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
|
||||
}
|
||||
|
||||
// StoreJSON adds a JSON blob into the queue JSON table and returns
|
||||
// a NID. The NID will then be used when inserting the per-destination
|
||||
// metadata entries.
|
||||
func (d *Database) StoreJSON(
|
||||
ctx context.Context, js string,
|
||||
) (*Receipt, error) {
|
||||
var nid int64
|
||||
var err error
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nid, err = d.FederationQueueJSON.InsertQueueJSON(ctx, txn, js)
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
|
||||
}
|
||||
return &Receipt{
|
||||
nid: nid,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (d *Database) PurgeRoomState(
|
||||
ctx context.Context, roomID string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// If the event is a create event then we'll delete all of the existing
|
||||
// data for the room. The only reason that a create event would be replayed
|
||||
// to us in this way is if we're about to receive the entire room state.
|
||||
if err := d.FederationJoinedHosts.DeleteJoinedHostsForRoom(ctx, txn, roomID); err != nil {
|
||||
return fmt.Errorf("d.FederationJoinedHosts.DeleteJoinedHosts: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveAllServersFromBlacklist() error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.DeleteAllBlacklist(context.TODO(), txn)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
|
||||
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
|
||||
}
|
||||
|
||||
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
|
||||
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
|
||||
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
|
||||
}
|
||||
|
||||
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
|
||||
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
validUntil := serverKeys.ValidUntilTS
|
||||
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
|
||||
// This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of
|
||||
// time without a way for the homeserver owner to revoke it.
|
||||
// https://spec.matrix.org/unstable/server-server-api/#querying-keys-through-another-server
|
||||
weekIntoFuture := time.Now().Add(7 * 24 * time.Hour)
|
||||
if weekIntoFuture.Before(validUntil.Time()) {
|
||||
validUntil = gomatrixserverlib.AsTimestamp(weekIntoFuture)
|
||||
}
|
||||
notaryID, err := d.NotaryServerKeysJSON.InsertJSONResponse(ctx, txn, serverKeys, serverName, validUntil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// update the metadata for the keys
|
||||
for keyID := range serverKeys.OldVerifyKeys {
|
||||
_, err = d.NotaryServerKeysMetadata.UpsertKey(ctx, txn, serverName, keyID, notaryID, validUntil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for keyID := range serverKeys.VerifyKeys {
|
||||
_, err = d.NotaryServerKeysMetadata.UpsertKey(ctx, txn, serverName, keyID, notaryID, validUntil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clean up old responses
|
||||
return d.NotaryServerKeysMetadata.DeleteOldJSONResponses(ctx, txn)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetNotaryKeys(
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
|
||||
) (sks []gomatrixserverlib.ServerKeys, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)
|
||||
return err
|
||||
})
|
||||
return sks, err
|
||||
}
|
151
federationapi/storage/shared/storage_edus.go
Normal file
151
federationapi/storage/shared/storage_edus.go
Normal file
|
@ -0,0 +1,151 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AssociateEDUWithDestination creates an association that the
|
||||
// destination queues will use to determine which JSON blobs to send
|
||||
// to which servers.
|
||||
func (d *Database) AssociateEDUWithDestination(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipt *Receipt,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.FederationQueueEDUs.InsertQueueEDU(
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
"", // TODO: EDU type for coalescing
|
||||
serverName, // destination server name
|
||||
receipt.nid, // NID from the federationapi_queue_json table
|
||||
); err != nil {
|
||||
return fmt.Errorf("InsertQueueEDU: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetNextTransactionEDUs retrieves events from the database for
|
||||
// the next pending transaction, up to the limit specified.
|
||||
func (d *Database) GetPendingEDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) (
|
||||
edus map[*Receipt]*gomatrixserverlib.EDU,
|
||||
err error,
|
||||
) {
|
||||
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueueEDUs: %w", err)
|
||||
}
|
||||
|
||||
retrieve := make([]int64, 0, len(nids))
|
||||
for _, nid := range nids {
|
||||
if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok {
|
||||
edus[&Receipt{nid}] = edu
|
||||
} else {
|
||||
retrieve = append(retrieve, nid)
|
||||
}
|
||||
}
|
||||
|
||||
blobs, err := d.FederationQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueueJSON: %w", err)
|
||||
}
|
||||
|
||||
for nid, blob := range blobs {
|
||||
var event gomatrixserverlib.EDU
|
||||
if err := json.Unmarshal(blob, &event); err != nil {
|
||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
edus[&Receipt{nid}] = &event
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CleanEDUs cleans up all specified EDUs. This is done when a
|
||||
// transaction was sent successfully.
|
||||
func (d *Database) CleanEDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipts []*Receipt,
|
||||
) error {
|
||||
if len(receipts) == 0 {
|
||||
return errors.New("expected receipt")
|
||||
}
|
||||
|
||||
nids := make([]int64, len(receipts))
|
||||
for i := range receipts {
|
||||
nids[i] = receipts[i].nid
|
||||
}
|
||||
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.FederationQueueEDUs.DeleteQueueEDUs(ctx, txn, serverName, nids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var deleteNIDs []int64
|
||||
for _, nid := range nids {
|
||||
count, err := d.FederationQueueEDUs.SelectQueueEDUReferenceJSONCount(ctx, txn, nid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueueEDUReferenceJSONCount: %w", err)
|
||||
}
|
||||
if count == 0 {
|
||||
deleteNIDs = append(deleteNIDs, nid)
|
||||
d.Cache.EvictFederationQueuedEDU(nid)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deleteNIDs) > 0 {
|
||||
if err := d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
|
||||
return fmt.Errorf("DeleteQueueJSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetPendingEDUCount returns the number of EDUs waiting to be
|
||||
// sent for a given servername.
|
||||
func (d *Database) GetPendingEDUCount(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.FederationQueueEDUs.SelectQueueEDUCount(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
// GetPendingServerNames returns the server names that have EDUs
|
||||
// waiting to be sent.
|
||||
func (d *Database) GetPendingEDUServerNames(
|
||||
ctx context.Context,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationQueueEDUs.SelectQueueEDUServerNames(ctx, nil)
|
||||
}
|
59
federationapi/storage/shared/storage_keys.go
Normal file
59
federationapi/storage/shared/storage_keys.go
Normal file
|
@ -0,0 +1,59 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// FetcherName implements KeyFetcher
|
||||
func (d Database) FetcherName() string {
|
||||
return "FederationAPIKeyDatabase"
|
||||
}
|
||||
|
||||
// FetchKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) FetchKeys(
|
||||
ctx context.Context,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
return d.ServerSigningKeys.BulkSelectServerKeys(ctx, nil, requests)
|
||||
}
|
||||
|
||||
// StoreKeys implements gomatrixserverlib.KeyDatabase
|
||||
func (d *Database) StoreKeys(
|
||||
ctx context.Context,
|
||||
keyMap map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var lastErr error
|
||||
for request, keys := range keyMap {
|
||||
if err := d.ServerSigningKeys.UpsertServerKeys(ctx, txn, request, keys); err != nil {
|
||||
// Rather than returning immediately on error we try to insert the
|
||||
// remaining keys.
|
||||
// Since we are inserting the keys outside of a transaction it is
|
||||
// possible for some of the inserts to succeed even though some
|
||||
// of the inserts have failed.
|
||||
// Ensuring that we always insert all the keys we can means that
|
||||
// this behaviour won't depend on the iteration order of the map.
|
||||
lastErr = err
|
||||
}
|
||||
}
|
||||
return lastErr
|
||||
})
|
||||
}
|
159
federationapi/storage/shared/storage_pdus.go
Normal file
159
federationapi/storage/shared/storage_pdus.go
Normal file
|
@ -0,0 +1,159 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package shared
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// AssociatePDUWithDestination creates an association that the
|
||||
// destination queues will use to determine which JSON blobs to send
|
||||
// to which servers.
|
||||
func (d *Database) AssociatePDUWithDestination(
|
||||
ctx context.Context,
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipt *Receipt,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.FederationQueuePDUs.InsertQueuePDU(
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
transactionID, // transaction ID
|
||||
serverName, // destination server name
|
||||
receipt.nid, // NID from the federationapi_queue_json table
|
||||
); err != nil {
|
||||
return fmt.Errorf("InsertQueuePDU: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetNextTransactionPDUs retrieves events from the database for
|
||||
// the next pending transaction, up to the limit specified.
|
||||
func (d *Database) GetPendingPDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) (
|
||||
events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
|
||||
err error,
|
||||
) {
|
||||
// Strictly speaking this doesn't need to be using the writer
|
||||
// since we are only performing selects, but since we don't have
|
||||
// a guarantee of transactional isolation, it's actually useful
|
||||
// to know in SQLite mode that nothing else is trying to modify
|
||||
// the database.
|
||||
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueuePDUs: %w", err)
|
||||
}
|
||||
|
||||
retrieve := make([]int64, 0, len(nids))
|
||||
for _, nid := range nids {
|
||||
if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok {
|
||||
events[&Receipt{nid}] = event
|
||||
} else {
|
||||
retrieve = append(retrieve, nid)
|
||||
}
|
||||
}
|
||||
|
||||
blobs, err := d.FederationQueueJSON.SelectQueueJSON(ctx, txn, retrieve)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueueJSON: %w", err)
|
||||
}
|
||||
|
||||
for nid, blob := range blobs {
|
||||
var event gomatrixserverlib.HeaderedEvent
|
||||
if err := json.Unmarshal(blob, &event); err != nil {
|
||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
events[&Receipt{nid}] = &event
|
||||
d.Cache.StoreFederationQueuedPDU(nid, &event)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// CleanTransactionPDUs cleans up all associated events for a
|
||||
// given transaction. This is done when the transaction was sent
|
||||
// successfully.
|
||||
func (d *Database) CleanPDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipts []*Receipt,
|
||||
) error {
|
||||
if len(receipts) == 0 {
|
||||
return errors.New("expected receipt")
|
||||
}
|
||||
|
||||
nids := make([]int64, len(receipts))
|
||||
for i := range receipts {
|
||||
nids[i] = receipts[i].nid
|
||||
}
|
||||
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.FederationQueuePDUs.DeleteQueuePDUs(ctx, txn, serverName, nids); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var deleteNIDs []int64
|
||||
for _, nid := range nids {
|
||||
count, err := d.FederationQueuePDUs.SelectQueuePDUReferenceJSONCount(ctx, txn, nid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SelectQueuePDUReferenceJSONCount: %w", err)
|
||||
}
|
||||
if count == 0 {
|
||||
deleteNIDs = append(deleteNIDs, nid)
|
||||
d.Cache.EvictFederationQueuedPDU(nid)
|
||||
}
|
||||
}
|
||||
|
||||
if len(deleteNIDs) > 0 {
|
||||
if err := d.FederationQueueJSON.DeleteQueueJSON(ctx, txn, deleteNIDs); err != nil {
|
||||
return fmt.Errorf("DeleteQueueJSON: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetPendingPDUCount returns the number of PDUs waiting to be
|
||||
// sent for a given servername.
|
||||
func (d *Database) GetPendingPDUCount(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.FederationQueuePDUs.SelectQueuePDUCount(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
// GetPendingServerNames returns the server names that have PDUs
|
||||
// waiting to be sent.
|
||||
func (d *Database) GetPendingPDUServerNames(
|
||||
ctx context.Context,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationQueuePDUs.SelectQueuePDUServerNames(ctx, nil)
|
||||
}
|
115
federationapi/storage/sqlite3/blacklist_table.go
Normal file
115
federationapi/storage/sqlite3/blacklist_table.go
Normal file
|
@ -0,0 +1,115 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const blacklistSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_blacklist (
|
||||
-- The blacklisted server name
|
||||
server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name)
|
||||
);
|
||||
`
|
||||
|
||||
const insertBlacklistSQL = "" +
|
||||
"INSERT INTO federationsender_blacklist (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectBlacklistSQL = "" +
|
||||
"SELECT server_name FROM federationsender_blacklist WHERE server_name = $1"
|
||||
|
||||
const deleteBlacklistSQL = "" +
|
||||
"DELETE FROM federationsender_blacklist WHERE server_name = $1"
|
||||
|
||||
const deleteAllBlacklistSQL = "" +
|
||||
"DELETE FROM federationsender_blacklist"
|
||||
|
||||
type blacklistStatements struct {
|
||||
db *sql.DB
|
||||
insertBlacklistStmt *sql.Stmt
|
||||
selectBlacklistStmt *sql.Stmt
|
||||
deleteBlacklistStmt *sql.Stmt
|
||||
deleteAllBlacklistStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteBlacklistTable(db *sql.DB) (s *blacklistStatements, err error) {
|
||||
s = &blacklistStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(blacklistSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertBlacklistStmt, err = db.Prepare(insertBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectBlacklistStmt, err = db.Prepare(selectBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteBlacklistStmt, err = db.Prepare(deleteBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteAllBlacklistStmt, err = db.Prepare(deleteAllBlacklistSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) InsertBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) SelectBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectBlacklistStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is blacklisted, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) DeleteBlacklist(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *blacklistStatements) DeleteAllBlacklist(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllBlacklistStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package deltas
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/pressly/goose"
|
||||
)
|
||||
|
||||
func LoadFromGoose() {
|
||||
goose.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
|
||||
}
|
||||
|
||||
func LoadRemoveRoomsTable(m *sqlutil.Migrations) {
|
||||
m.AddMigration(UpRemoveRoomsTable, DownRemoveRoomsTable)
|
||||
}
|
||||
|
||||
func UpRemoveRoomsTable(tx *sql.Tx) error {
|
||||
_, err := tx.Exec(`
|
||||
DROP TABLE IF EXISTS federationsender_rooms;
|
||||
`)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to execute upgrade: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func DownRemoveRoomsTable(tx *sql.Tx) error {
|
||||
// We can't reverse this.
|
||||
return nil
|
||||
}
|
176
federationapi/storage/sqlite3/inbound_peeks_table.go
Normal file
176
federationapi/storage/sqlite3/inbound_peeks_table.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const inboundPeeksSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_inbound_peeks (
|
||||
room_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
peek_id TEXT NOT NULL,
|
||||
creation_ts INTEGER NOT NULL,
|
||||
renewed_ts INTEGER NOT NULL,
|
||||
renewal_interval INTEGER NOT NULL,
|
||||
UNIQUE (room_id, server_name, peek_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertInboundPeekSQL = "" +
|
||||
"INSERT INTO federationsender_inbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const selectInboundPeekSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectInboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
||||
const renewInboundPeekSQL = "" +
|
||||
"UPDATE federationsender_inbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteInboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
|
||||
const deleteInboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_inbound_peeks WHERE room_id = $1"
|
||||
|
||||
type inboundPeeksStatements struct {
|
||||
db *sql.DB
|
||||
insertInboundPeekStmt *sql.Stmt
|
||||
selectInboundPeekStmt *sql.Stmt
|
||||
selectInboundPeeksStmt *sql.Stmt
|
||||
renewInboundPeekStmt *sql.Stmt
|
||||
deleteInboundPeekStmt *sql.Stmt
|
||||
deleteInboundPeeksStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteInboundPeeksTable(db *sql.DB) (s *inboundPeeksStatements, err error) {
|
||||
s = &inboundPeeksStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(inboundPeeksSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertInboundPeekStmt, err = db.Prepare(insertInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeekStmt, err = db.Prepare(selectInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectInboundPeeksStmt, err = db.Prepare(selectInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewInboundPeekStmt, err = db.Prepare(renewInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeeksStmt, err = db.Prepare(deleteInboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteInboundPeekStmt, err = db.Prepare(deleteInboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) InsertInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertInboundPeekStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) RenewInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
_, err = sqlutil.TxStmt(txn, s.renewInboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) SelectInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (*types.InboundPeek, error) {
|
||||
row := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
inboundPeek := types.InboundPeek{}
|
||||
err := row.Scan(
|
||||
&inboundPeek.RoomID,
|
||||
&inboundPeek.ServerName,
|
||||
&inboundPeek.PeekID,
|
||||
&inboundPeek.CreationTimestamp,
|
||||
&inboundPeek.RenewedTimestamp,
|
||||
&inboundPeek.RenewalInterval,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &inboundPeek, nil
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) SelectInboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (inboundPeeks []types.InboundPeek, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectInboundPeeksStmt).QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectInboundPeeks: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
inboundPeek := types.InboundPeek{}
|
||||
if err = rows.Scan(
|
||||
&inboundPeek.RoomID,
|
||||
&inboundPeek.ServerName,
|
||||
&inboundPeek.PeekID,
|
||||
&inboundPeek.CreationTimestamp,
|
||||
&inboundPeek.RenewedTimestamp,
|
||||
&inboundPeek.RenewalInterval,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
inboundPeeks = append(inboundPeeks, inboundPeek)
|
||||
}
|
||||
|
||||
return inboundPeeks, rows.Err()
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) DeleteInboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *inboundPeeksStatements) DeleteInboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteInboundPeeksStmt).ExecContext(ctx, roomID)
|
||||
return
|
||||
}
|
219
federationapi/storage/sqlite3/joined_hosts_table.go
Normal file
219
federationapi/storage/sqlite3/joined_hosts_table.go
Normal file
|
@ -0,0 +1,219 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const joinedHostsSchema = `
|
||||
-- The joined_hosts table stores a list of m.room.member event ids in the
|
||||
-- current state for each room where the membership is "join".
|
||||
-- There will be an entry for every user that is joined to the room.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_joined_hosts (
|
||||
-- The string ID of the room.
|
||||
room_id TEXT NOT NULL,
|
||||
-- The event ID of the m.room.member join event.
|
||||
event_id TEXT NOT NULL,
|
||||
-- The domain part of the user ID the m.room.member event is for.
|
||||
server_name TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federatonsender_joined_hosts_event_id_idx
|
||||
ON federationsender_joined_hosts (event_id);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS federatonsender_joined_hosts_room_id_idx
|
||||
ON federationsender_joined_hosts (room_id)
|
||||
`
|
||||
|
||||
const insertJoinedHostsSQL = "" +
|
||||
"INSERT OR IGNORE INTO federationsender_joined_hosts (room_id, event_id, server_name)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteJoinedHostsSQL = "" +
|
||||
"DELETE FROM federationsender_joined_hosts WHERE event_id = $1"
|
||||
|
||||
const deleteJoinedHostsForRoomSQL = "" +
|
||||
"DELETE FROM federationsender_joined_hosts WHERE room_id = $1"
|
||||
|
||||
const selectJoinedHostsSQL = "" +
|
||||
"SELECT event_id, server_name FROM federationsender_joined_hosts" +
|
||||
" WHERE room_id = $1"
|
||||
|
||||
const selectAllJoinedHostsSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
|
||||
|
||||
const selectJoinedHostsForRoomsSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
|
||||
|
||||
type joinedHostsStatements struct {
|
||||
db *sql.DB
|
||||
insertJoinedHostsStmt *sql.Stmt
|
||||
deleteJoinedHostsStmt *sql.Stmt
|
||||
deleteJoinedHostsForRoomStmt *sql.Stmt
|
||||
selectJoinedHostsStmt *sql.Stmt
|
||||
selectAllJoinedHostsStmt *sql.Stmt
|
||||
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||
s = &joinedHostsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(joinedHostsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJoinedHostsStmt, err = db.Prepare(insertJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJoinedHostsStmt, err = db.Prepare(deleteJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteJoinedHostsForRoomStmt, err = s.db.Prepare(deleteJoinedHostsForRoomSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) InsertJoinedHosts(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomID, eventID string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJoinedHostsStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID, eventID, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) DeleteJoinedHosts(
|
||||
ctx context.Context, txn *sql.Tx, eventIDs []string,
|
||||
) error {
|
||||
for _, eventID := range eventIDs {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsStmt)
|
||||
if _, err := stmt.ExecContext(ctx, eventID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) DeleteJoinedHostsForRoom(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteJoinedHostsForRoomStmt)
|
||||
_, err := stmt.ExecContext(ctx, roomID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHostsWithTx(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectJoinedHostsStmt)
|
||||
return joinedHostsFromStmt(ctx, stmt, roomID)
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHosts(
|
||||
ctx context.Context, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||
ctx context.Context,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||
ctx context.Context, roomIDs []string,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||
for i := range roomIDs {
|
||||
iRoomIDs[i] = roomIDs[i]
|
||||
}
|
||||
|
||||
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
|
||||
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJoinedHostsForRoomsStmt: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(serverName))
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func joinedHostsFromStmt(
|
||||
ctx context.Context, stmt *sql.Stmt, roomID string,
|
||||
) ([]types.JoinedHost, error) {
|
||||
rows, err := stmt.QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "joinedHostsFromStmt: rows.close() failed")
|
||||
|
||||
var result []types.JoinedHost
|
||||
for rows.Next() {
|
||||
var eventID, serverName string
|
||||
if err = rows.Scan(&eventID, &serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, types.JoinedHost{
|
||||
MemberEventID: eventID,
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
|
@ -0,0 +1,63 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const notaryServerKeysJSONSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_json (
|
||||
notary_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
response_json TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
valid_until BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertServerKeysJSONSQL = "" +
|
||||
"INSERT INTO federationsender_notary_server_keys_json (response_json, server_name, valid_until) VALUES ($1, $2, $3)" +
|
||||
" RETURNING notary_id"
|
||||
|
||||
type notaryServerKeysStatements struct {
|
||||
db *sql.DB
|
||||
insertServerKeysJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteNotaryServerKeysTable(db *sql.DB) (s *notaryServerKeysStatements, err error) {
|
||||
s = ¬aryServerKeysStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(notaryServerKeysJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertServerKeysJSONStmt, err = db.Prepare(insertServerKeysJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysStatements) InsertJSONResponse(
|
||||
ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp,
|
||||
) (tables.NotaryID, error) {
|
||||
var notaryID tables.NotaryID
|
||||
return notaryID, txn.Stmt(s.insertServerKeysJSONStmt).QueryRowContext(ctx, string(keyQueryResponseJSON.Raw), serverName, validUntil).Scan(¬aryID)
|
||||
}
|
|
@ -0,0 +1,169 @@
|
|||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const notaryServerKeysMetadataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_notary_server_keys_metadata (
|
||||
notary_id BIGINT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
key_id TEXT NOT NULL,
|
||||
UNIQUE (server_name, key_id)
|
||||
);
|
||||
`
|
||||
|
||||
const upsertServerKeysSQL = "" +
|
||||
"INSERT INTO federationsender_notary_server_keys_metadata (notary_id, server_name, key_id) VALUES ($1, $2, $3)" +
|
||||
" ON CONFLICT (server_name, key_id) DO UPDATE SET notary_id = $1"
|
||||
|
||||
// for a given (server_name, key_id), find the existing notary ID and valid until. Used to check if we will replace it
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyMetadataSQL = `
|
||||
SELECT federationsender_notary_server_keys_metadata.notary_id, valid_until FROM federationsender_notary_server_keys_json
|
||||
JOIN federationsender_notary_server_keys_metadata ON
|
||||
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
|
||||
WHERE federationsender_notary_server_keys_metadata.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id = $2
|
||||
`
|
||||
|
||||
// select the response which has the highest valid_until value
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyResponsesSQL = `
|
||||
SELECT response_json FROM federationsender_notary_server_keys_json
|
||||
WHERE server_name = $1 AND valid_until = (
|
||||
SELECT MAX(valid_until) FROM federationsender_notary_server_keys_json WHERE server_name = $1
|
||||
)
|
||||
`
|
||||
|
||||
// select the responses which have the given key IDs
|
||||
// JOINs with the json table
|
||||
const selectNotaryKeyResponsesWithKeyIDsSQL = `
|
||||
SELECT response_json FROM federationsender_notary_server_keys_json
|
||||
JOIN federationsender_notary_server_keys_metadata ON
|
||||
federationsender_notary_server_keys_metadata.notary_id = federationsender_notary_server_keys_json.notary_id
|
||||
WHERE federationsender_notary_server_keys_json.server_name = $1 AND federationsender_notary_server_keys_metadata.key_id IN ($2)
|
||||
GROUP BY federationsender_notary_server_keys_json.notary_id
|
||||
`
|
||||
|
||||
// JOINs with the metadata table
|
||||
const deleteUnusedServerKeysJSONSQL = `
|
||||
DELETE FROM federationsender_notary_server_keys_json WHERE federationsender_notary_server_keys_json.notary_id NOT IN (
|
||||
SELECT DISTINCT notary_id FROM federationsender_notary_server_keys_metadata
|
||||
)
|
||||
`
|
||||
|
||||
type notaryServerKeysMetadataStatements struct {
|
||||
db *sql.DB
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
selectNotaryKeyResponsesStmt *sql.Stmt
|
||||
selectNotaryKeyMetadataStmt *sql.Stmt
|
||||
deleteUnusedServerKeysJSONStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteNotaryServerKeysMetadataTable(db *sql.DB) (s *notaryServerKeysMetadataStatements, err error) {
|
||||
s = ¬aryServerKeysMetadataStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(notaryServerKeysMetadataSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNotaryKeyResponsesStmt, err = db.Prepare(selectNotaryKeyResponsesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectNotaryKeyMetadataStmt, err = db.Prepare(selectNotaryKeyMetadataSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteUnusedServerKeysJSONStmt, err = db.Prepare(deleteUnusedServerKeysJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) UpsertKey(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID tables.NotaryID, newValidUntil gomatrixserverlib.Timestamp,
|
||||
) (tables.NotaryID, error) {
|
||||
notaryID := newNotaryID
|
||||
// see if the existing notary ID a) exists, b) has a longer valid_until
|
||||
var existingNotaryID tables.NotaryID
|
||||
var existingValidUntil gomatrixserverlib.Timestamp
|
||||
if err := txn.Stmt(s.selectNotaryKeyMetadataStmt).QueryRowContext(ctx, serverName, keyID).Scan(&existingNotaryID, &existingValidUntil); err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if existingValidUntil.Time().After(newValidUntil.Time()) {
|
||||
// the existing valid_until is valid longer, so use that.
|
||||
return existingNotaryID, nil
|
||||
}
|
||||
// overwrite the notary_id for this (server_name, key_id) tuple
|
||||
_, err := txn.Stmt(s.upsertServerKeysStmt).ExecContext(ctx, notaryID, serverName, keyID)
|
||||
return notaryID, err
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error) {
|
||||
var rows *sql.Rows
|
||||
var err error
|
||||
if len(keyIDs) == 0 {
|
||||
rows, err = txn.Stmt(s.selectNotaryKeyResponsesStmt).QueryContext(ctx, string(serverName))
|
||||
} else {
|
||||
iKeyIDs := make([]interface{}, len(keyIDs)+1)
|
||||
iKeyIDs[0] = serverName
|
||||
for i := range keyIDs {
|
||||
iKeyIDs[i+1] = string(keyIDs[i])
|
||||
}
|
||||
sql := strings.Replace(selectNotaryKeyResponsesWithKeyIDsSQL, "($2)", sqlutil.QueryVariadicOffset(len(keyIDs), 1), 1)
|
||||
fmt.Println(sql)
|
||||
fmt.Println(iKeyIDs...)
|
||||
rows, err = s.db.QueryContext(ctx, sql, iKeyIDs...)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectNotaryKeyResponsesStmt close failed")
|
||||
var results []gomatrixserverlib.ServerKeys
|
||||
for rows.Next() {
|
||||
var sk gomatrixserverlib.ServerKeys
|
||||
var raw string
|
||||
if err = rows.Scan(&raw); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = json.Unmarshal([]byte(raw), &sk); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, sk)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *notaryServerKeysMetadataStatements) DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error {
|
||||
_, err := txn.Stmt(s.deleteUnusedServerKeysJSONStmt).ExecContext(ctx)
|
||||
return err
|
||||
}
|
176
federationapi/storage/sqlite3/outbound_peeks_table.go
Normal file
176
federationapi/storage/sqlite3/outbound_peeks_table.go
Normal file
|
@ -0,0 +1,176 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const outboundPeeksSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_outbound_peeks (
|
||||
room_id TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
peek_id TEXT NOT NULL,
|
||||
creation_ts INTEGER NOT NULL,
|
||||
renewed_ts INTEGER NOT NULL,
|
||||
renewal_interval INTEGER NOT NULL,
|
||||
UNIQUE (room_id, server_name, peek_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertOutboundPeekSQL = "" +
|
||||
"INSERT INTO federationsender_outbound_peeks (room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const selectOutboundPeekSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2 and peek_id = $3"
|
||||
|
||||
const selectOutboundPeeksSQL = "" +
|
||||
"SELECT room_id, server_name, peek_id, creation_ts, renewed_ts, renewal_interval FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
||||
const renewOutboundPeekSQL = "" +
|
||||
"UPDATE federationsender_outbound_peeks SET renewed_ts=$1, renewal_interval=$2 WHERE room_id = $3 and server_name = $4 and peek_id = $5"
|
||||
|
||||
const deleteOutboundPeekSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1 and server_name = $2"
|
||||
|
||||
const deleteOutboundPeeksSQL = "" +
|
||||
"DELETE FROM federationsender_outbound_peeks WHERE room_id = $1"
|
||||
|
||||
type outboundPeeksStatements struct {
|
||||
db *sql.DB
|
||||
insertOutboundPeekStmt *sql.Stmt
|
||||
selectOutboundPeekStmt *sql.Stmt
|
||||
selectOutboundPeeksStmt *sql.Stmt
|
||||
renewOutboundPeekStmt *sql.Stmt
|
||||
deleteOutboundPeekStmt *sql.Stmt
|
||||
deleteOutboundPeeksStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteOutboundPeeksTable(db *sql.DB) (s *outboundPeeksStatements, err error) {
|
||||
s = &outboundPeeksStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(outboundPeeksSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertOutboundPeekStmt, err = db.Prepare(insertOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeekStmt, err = db.Prepare(selectOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectOutboundPeeksStmt, err = db.Prepare(selectOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.renewOutboundPeekStmt, err = db.Prepare(renewOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeeksStmt, err = db.Prepare(deleteOutboundPeeksSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteOutboundPeekStmt, err = db.Prepare(deleteOutboundPeekSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) InsertOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
stmt := sqlutil.TxStmt(txn, s.insertOutboundPeekStmt)
|
||||
_, err = stmt.ExecContext(ctx, roomID, serverName, peekID, nowMilli, nowMilli, renewalInterval)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) RenewOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64,
|
||||
) (err error) {
|
||||
nowMilli := time.Now().UnixNano() / int64(time.Millisecond)
|
||||
_, err = sqlutil.TxStmt(txn, s.renewOutboundPeekStmt).ExecContext(ctx, nowMilli, renewalInterval, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) SelectOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (*types.OutboundPeek, error) {
|
||||
row := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryRowContext(ctx, roomID)
|
||||
outboundPeek := types.OutboundPeek{}
|
||||
err := row.Scan(
|
||||
&outboundPeek.RoomID,
|
||||
&outboundPeek.ServerName,
|
||||
&outboundPeek.PeekID,
|
||||
&outboundPeek.CreationTimestamp,
|
||||
&outboundPeek.RenewedTimestamp,
|
||||
&outboundPeek.RenewalInterval,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &outboundPeek, nil
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) SelectOutboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (outboundPeeks []types.OutboundPeek, err error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectOutboundPeeksStmt).QueryContext(ctx, roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectOutboundPeeks: rows.close() failed")
|
||||
|
||||
for rows.Next() {
|
||||
outboundPeek := types.OutboundPeek{}
|
||||
if err = rows.Scan(
|
||||
&outboundPeek.RoomID,
|
||||
&outboundPeek.ServerName,
|
||||
&outboundPeek.PeekID,
|
||||
&outboundPeek.CreationTimestamp,
|
||||
&outboundPeek.RenewedTimestamp,
|
||||
&outboundPeek.RenewalInterval,
|
||||
); err != nil {
|
||||
return
|
||||
}
|
||||
outboundPeeks = append(outboundPeeks, outboundPeek)
|
||||
}
|
||||
|
||||
return outboundPeeks, rows.Err()
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) DeleteOutboundPeek(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeekStmt).ExecContext(ctx, roomID, serverName, peekID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *outboundPeeksStatements) DeleteOutboundPeeks(
|
||||
ctx context.Context, txn *sql.Tx, roomID string,
|
||||
) (err error) {
|
||||
_, err = sqlutil.TxStmt(txn, s.deleteOutboundPeeksStmt).ExecContext(ctx, roomID)
|
||||
return
|
||||
}
|
207
federationapi/storage/sqlite3/queue_edus_table.go
Normal file
207
federationapi/storage/sqlite3/queue_edus_table.go
Normal file
|
@ -0,0 +1,207 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queueEDUsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_edus (
|
||||
-- The type of the event (informational).
|
||||
edu_type TEXT NOT NULL,
|
||||
-- The domain part of the user ID the EDU event is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_queue_edus_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_edus_json_nid_idx
|
||||
ON federationsender_queue_edus (json_nid, server_name);
|
||||
`
|
||||
|
||||
const insertQueueEDUSQL = "" +
|
||||
"INSERT INTO federationsender_queue_edus (edu_type, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueueEDUsSQL = "" +
|
||||
"DELETE FROM federationsender_queue_edus WHERE server_name = $1 AND json_nid IN ($2)"
|
||||
|
||||
const selectQueueEDUSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueueEDUReferenceJSONCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueueEDUCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_edus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueueServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_edus"
|
||||
|
||||
type queueEDUsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUStmt *sql.Stmt
|
||||
selectQueueEDUReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueueEDUCountStmt *sql.Stmt
|
||||
selectQueueEDUServerNamesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteQueueEDUsTable(db *sql.DB) (s *queueEDUsStatements, err error) {
|
||||
s = &queueEDUsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(queueEDUsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueueEDUStmt, err = db.Prepare(insertQueueEDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUStmt, err = db.Prepare(selectQueueEDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUReferenceJSONCountStmt, err = db.Prepare(selectQueueEDUReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUCountStmt, err = db.Prepare(selectQueueEDUCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueEDUServerNamesStmt, err = db.Prepare(selectQueueServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) InsertQueueEDU(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
eduType string,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueueEDUStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
eduType, // the EDU type
|
||||
serverName, // destination server name
|
||||
nid, // JSON blob NID
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) DeleteQueueEDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteQueueEDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
|
||||
deleteStmt, err := txn.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
params := make([]interface{}, len(jsonNIDs)+1)
|
||||
params[0] = serverName
|
||||
for k, v := range jsonNIDs {
|
||||
params[k+1] = v
|
||||
}
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []int64
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
if err = rows.Scan(&nid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, nid)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUReferenceJSONCount(
|
||||
ctx context.Context, txn *sql.Tx, jsonNID int64,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUReferenceJSONCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
return -1, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queueEDUsStatements) SelectQueueEDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueEDUServerNamesStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, serverName)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
136
federationapi/storage/sqlite3/queue_json_table.go
Normal file
136
federationapi/storage/sqlite3/queue_json_table.go
Normal file
|
@ -0,0 +1,136 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
)
|
||||
|
||||
const queueJSONSchema = `
|
||||
-- The queue_retry_json table contains event contents that
|
||||
-- we failed to send.
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_json (
|
||||
-- The JSON NID. This allows the federationsender_queue_retry table to
|
||||
-- cross-reference to find the JSON blob.
|
||||
json_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The JSON body. Text so that we preserve UTF-8.
|
||||
json_body TEXT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertJSONSQL = "" +
|
||||
"INSERT INTO federationsender_queue_json (json_body)" +
|
||||
" VALUES ($1)"
|
||||
|
||||
const deleteJSONSQL = "" +
|
||||
"DELETE FROM federationsender_queue_json WHERE json_nid IN ($1)"
|
||||
|
||||
const selectJSONSQL = "" +
|
||||
"SELECT json_nid, json_body FROM federationsender_queue_json" +
|
||||
" WHERE json_nid IN ($1)"
|
||||
|
||||
type queueJSONStatements struct {
|
||||
db *sql.DB
|
||||
insertJSONStmt *sql.Stmt
|
||||
//deleteJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
//selectJSONStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSQLiteQueueJSONTable(db *sql.DB) (s *queueJSONStatements, err error) {
|
||||
s = &queueJSONStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(queueJSONSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertJSONStmt, err = db.Prepare(insertJSONSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) InsertQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, json string,
|
||||
) (lastid int64, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertJSONStmt)
|
||||
res, err := stmt.ExecContext(ctx, json)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("stmt.QueryContext: %w", err)
|
||||
}
|
||||
lastid, err = res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("res.LastInsertId: %w", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) DeleteQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, nids []int64,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteJSONSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
deleteStmt, err := txn.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
iNIDs := make([]interface{}, len(nids))
|
||||
for k, v := range nids {
|
||||
iNIDs[k] = v
|
||||
}
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||
_, err = stmt.ExecContext(ctx, iNIDs...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queueJSONStatements) SelectQueueJSON(
|
||||
ctx context.Context, txn *sql.Tx, jsonNIDs []int64,
|
||||
) (map[int64][]byte, error) {
|
||||
selectSQL := strings.Replace(selectJSONSQL, "($1)", sqlutil.QueryVariadic(len(jsonNIDs)), 1)
|
||||
selectStmt, err := txn.Prepare(selectSQL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.selectQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
iNIDs := make([]interface{}, len(jsonNIDs))
|
||||
for k, v := range jsonNIDs {
|
||||
iNIDs[k] = v
|
||||
}
|
||||
|
||||
blobs := map[int64][]byte{}
|
||||
stmt := sqlutil.TxStmt(txn, selectStmt)
|
||||
rows, err := stmt.QueryContext(ctx, iNIDs...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("s.selectQueueJSON stmt.QueryContext: %w", err)
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "selectJSON: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
var blob []byte
|
||||
if err = rows.Scan(&nid, &blob); err != nil {
|
||||
return nil, fmt.Errorf("s.selectQueueJSON rows.Scan: %w", err)
|
||||
}
|
||||
blobs[nid] = blob
|
||||
}
|
||||
return blobs, err
|
||||
}
|
235
federationapi/storage/sqlite3/queue_pdus_table.go
Normal file
235
federationapi/storage/sqlite3/queue_pdus_table.go
Normal file
|
@ -0,0 +1,235 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const queuePDUsSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_queue_pdus (
|
||||
-- The transaction ID that was generated before persisting the event.
|
||||
transaction_id TEXT NOT NULL,
|
||||
-- The domain part of the user ID the m.room.member event is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The JSON NID from the federationsender_queue_pdus_json table.
|
||||
json_nid BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS federationsender_queue_pdus_pdus_json_nid_idx
|
||||
ON federationsender_queue_pdus (json_nid, server_name);
|
||||
`
|
||||
|
||||
const insertQueuePDUSQL = "" +
|
||||
"INSERT INTO federationsender_queue_pdus (transaction_id, server_name, json_nid)" +
|
||||
" VALUES ($1, $2, $3)"
|
||||
|
||||
const deleteQueuePDUsSQL = "" +
|
||||
"DELETE FROM federationsender_queue_pdus WHERE server_name = $1 AND json_nid IN ($2)"
|
||||
|
||||
const selectQueueNextTransactionIDSQL = "" +
|
||||
"SELECT transaction_id FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1" +
|
||||
" ORDER BY transaction_id ASC" +
|
||||
" LIMIT 1"
|
||||
|
||||
const selectQueuePDUsSQL = "" +
|
||||
"SELECT json_nid FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1" +
|
||||
" LIMIT $2"
|
||||
|
||||
const selectQueuePDUsReferenceJSONCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE json_nid = $1"
|
||||
|
||||
const selectQueuePDUsCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM federationsender_queue_pdus" +
|
||||
" WHERE server_name = $1"
|
||||
|
||||
const selectQueuePDUsServerNamesSQL = "" +
|
||||
"SELECT DISTINCT server_name FROM federationsender_queue_pdus"
|
||||
|
||||
type queuePDUsStatements struct {
|
||||
db *sql.DB
|
||||
insertQueuePDUStmt *sql.Stmt
|
||||
selectQueueNextTransactionIDStmt *sql.Stmt
|
||||
selectQueuePDUsStmt *sql.Stmt
|
||||
selectQueueReferenceJSONCountStmt *sql.Stmt
|
||||
selectQueuePDUsCountStmt *sql.Stmt
|
||||
selectQueueServerNamesStmt *sql.Stmt
|
||||
// deleteQueuePDUsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
}
|
||||
|
||||
func NewSQLiteQueuePDUsTable(db *sql.DB) (s *queuePDUsStatements, err error) {
|
||||
s = &queuePDUsStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(queuePDUsSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.insertQueuePDUStmt, err = db.Prepare(insertQueuePDUSQL); err != nil {
|
||||
return
|
||||
}
|
||||
//if s.deleteQueuePDUsStmt, err = db.Prepare(deleteQueuePDUsSQL); err != nil {
|
||||
// return
|
||||
//}
|
||||
if s.selectQueueNextTransactionIDStmt, err = db.Prepare(selectQueueNextTransactionIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsStmt, err = db.Prepare(selectQueuePDUsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueReferenceJSONCountStmt, err = db.Prepare(selectQueuePDUsReferenceJSONCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueuePDUsCountStmt, err = db.Prepare(selectQueuePDUsCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectQueueServerNamesStmt, err = db.Prepare(selectQueuePDUsServerNamesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) InsertQueuePDU(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
transactionID gomatrixserverlib.TransactionID,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
nid int64,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertQueuePDUStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
transactionID, // the transaction ID that we initially attempted
|
||||
serverName, // destination server name
|
||||
nid, // JSON blob NID
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) DeleteQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
jsonNIDs []int64,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteQueuePDUsSQL, "($2)", sqlutil.QueryVariadicOffset(len(jsonNIDs), 1), 1)
|
||||
deleteStmt, err := txn.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("s.deleteQueueJSON s.db.Prepare: %w", err)
|
||||
}
|
||||
|
||||
params := make([]interface{}, len(jsonNIDs)+1)
|
||||
params[0] = serverName
|
||||
for k, v := range jsonNIDs {
|
||||
params[k+1] = v
|
||||
}
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUNextTransactionID(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (gomatrixserverlib.TransactionID, error) {
|
||||
var transactionID gomatrixserverlib.TransactionID
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueNextTransactionIDStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&transactionID)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
return transactionID, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUReferenceJSONCount(
|
||||
ctx context.Context, txn *sql.Tx, jsonNID int64,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueReferenceJSONCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, jsonNID).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
return -1, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUCount(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
var count int64
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsCountStmt)
|
||||
err := stmt.QueryRowContext(ctx, serverName).Scan(&count)
|
||||
if err == sql.ErrNoRows {
|
||||
// It's acceptable for there to be no rows referencing a given
|
||||
// JSON NID but it's not an error condition. Just return as if
|
||||
// there's a zero count.
|
||||
return 0, nil
|
||||
}
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUs(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) ([]int64, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueuePDUsStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName, limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []int64
|
||||
for rows.Next() {
|
||||
var nid int64
|
||||
if err = rows.Scan(&nid); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, nid)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
||||
|
||||
func (s *queuePDUsStatements) SelectQueuePDUServerNames(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectQueueServerNamesStmt)
|
||||
rows, err := stmt.QueryContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "queueFromStmt: rows.close() failed")
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
if err = rows.Scan(&serverName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, serverName)
|
||||
}
|
||||
|
||||
return result, rows.Err()
|
||||
}
|
157
federationapi/storage/sqlite3/server_key_table.go
Normal file
157
federationapi/storage/sqlite3/server_key_table.go
Normal file
|
@ -0,0 +1,157 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const serverSigningKeysSchema = `
|
||||
-- A cache of signing keys downloaded from remote servers.
|
||||
CREATE TABLE IF NOT EXISTS keydb_server_keys (
|
||||
-- The name of the matrix server the key is for.
|
||||
server_name TEXT NOT NULL,
|
||||
-- The ID of the server key.
|
||||
server_key_id TEXT NOT NULL,
|
||||
-- Combined server name and key ID separated by the ASCII unit separator
|
||||
-- to make it easier to run bulk queries.
|
||||
server_name_and_key_id TEXT NOT NULL,
|
||||
-- When the key is valid until as a millisecond timestamp.
|
||||
-- 0 if this is an expired key (in which case expired_ts will be non-zero)
|
||||
valid_until_ts BIGINT NOT NULL,
|
||||
-- When the key expired as a millisecond timestamp.
|
||||
-- 0 if this is an active key (in which case valid_until_ts will be non-zero)
|
||||
expired_ts BIGINT NOT NULL,
|
||||
-- The base64-encoded public key.
|
||||
server_key TEXT NOT NULL,
|
||||
UNIQUE (server_name, server_key_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS keydb_server_name_and_key_id ON keydb_server_keys (server_name_and_key_id);
|
||||
`
|
||||
|
||||
const bulkSelectServerSigningKeysSQL = "" +
|
||||
"SELECT server_name, server_key_id, valid_until_ts, expired_ts, " +
|
||||
" server_key FROM keydb_server_keys" +
|
||||
" WHERE server_name_and_key_id IN ($1)"
|
||||
|
||||
const upsertServerSigningKeysSQL = "" +
|
||||
"INSERT INTO keydb_server_keys (server_name, server_key_id," +
|
||||
" server_name_and_key_id, valid_until_ts, expired_ts, server_key)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6)" +
|
||||
" ON CONFLICT (server_name, server_key_id)" +
|
||||
" DO UPDATE SET valid_until_ts = $4, expired_ts = $5, server_key = $6"
|
||||
|
||||
type serverSigningKeyStatements struct {
|
||||
db *sql.DB
|
||||
bulkSelectServerKeysStmt *sql.Stmt
|
||||
upsertServerKeysStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteServerSigningKeysTable(db *sql.DB) (s *serverSigningKeyStatements, err error) {
|
||||
s = &serverSigningKeyStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(serverSigningKeysSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if s.bulkSelectServerKeysStmt, err = db.Prepare(bulkSelectServerSigningKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.upsertServerKeysStmt, err = db.Prepare(upsertServerSigningKeysSQL); err != nil {
|
||||
return
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (s *serverSigningKeyStatements) BulkSelectServerKeys(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
|
||||
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
|
||||
nameAndKeyIDs := make([]string, 0, len(requests))
|
||||
for request := range requests {
|
||||
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
|
||||
}
|
||||
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
|
||||
iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
|
||||
for i, v := range nameAndKeyIDs {
|
||||
iKeyIDs[i] = v
|
||||
}
|
||||
|
||||
err := sqlutil.RunLimitedVariablesQuery(
|
||||
ctx, bulkSelectServerSigningKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var serverName string
|
||||
var keyID string
|
||||
var key string
|
||||
var validUntilTS int64
|
||||
var expiredTS int64
|
||||
if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
r := gomatrixserverlib.PublicKeyLookupRequest{
|
||||
ServerName: gomatrixserverlib.ServerName(serverName),
|
||||
KeyID: gomatrixserverlib.KeyID(keyID),
|
||||
}
|
||||
vk := gomatrixserverlib.VerifyKey{}
|
||||
err := vk.Key.Decode(key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bulkSelectServerKeys: %v", err)
|
||||
}
|
||||
results[r] = gomatrixserverlib.PublicKeyLookupResult{
|
||||
VerifyKey: vk,
|
||||
ValidUntilTS: gomatrixserverlib.Timestamp(validUntilTS),
|
||||
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *serverSigningKeyStatements) UpsertServerKeys(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
request gomatrixserverlib.PublicKeyLookupRequest,
|
||||
key gomatrixserverlib.PublicKeyLookupResult,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.upsertServerKeysStmt)
|
||||
_, err := stmt.ExecContext(
|
||||
ctx,
|
||||
string(request.ServerName),
|
||||
string(request.KeyID),
|
||||
nameAndKeyID(request),
|
||||
key.ValidUntilTS,
|
||||
key.ExpiredTS,
|
||||
key.Key.Encode(),
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
func nameAndKeyID(request gomatrixserverlib.PublicKeyLookupRequest) string {
|
||||
return string(request.ServerName) + "\x1F" + string(request.KeyID)
|
||||
}
|
108
federationapi/storage/sqlite3/storage.go
Normal file
108
federationapi/storage/sqlite3/storage.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3/deltas"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// Database stores information needed by the federation sender
|
||||
type Database struct {
|
||||
shared.Database
|
||||
sqlutil.PartitionOffsetStatements
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
}
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (*Database, error) {
|
||||
var d Database
|
||||
var err error
|
||||
if d.db, err = sqlutil.Open(dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.writer = sqlutil.NewExclusiveWriter()
|
||||
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queuePDUs, err := NewSQLiteQueuePDUsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueEDUs, err := NewSQLiteQueueEDUsTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
queueJSON, err := NewSQLiteQueueJSONTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
blacklist, err := NewSQLiteBlacklistTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inboundPeeks, err := NewSQLiteInboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
notaryKeys, err := NewSQLiteNotaryServerKeysTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
notaryKeysMetadata, err := NewSQLiteNotaryServerKeysMetadataTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
serverSigningKeys, err := NewSQLiteServerSigningKeysTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m := sqlutil.NewMigrations()
|
||||
deltas.LoadRemoveRoomsTable(m)
|
||||
if err = m.RunDeltas(d.db, dbProperties); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: d.db,
|
||||
Cache: cache,
|
||||
Writer: d.writer,
|
||||
FederationJoinedHosts: joinedHosts,
|
||||
FederationQueuePDUs: queuePDUs,
|
||||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
NotaryServerKeysJSON: notaryKeys,
|
||||
NotaryServerKeysMetadata: notaryKeysMetadata,
|
||||
ServerSigningKeys: serverSigningKeys,
|
||||
}
|
||||
if err = d.PartitionOffsetStatements.Prepare(d.db, d.writer, "federationsender"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &d, nil
|
||||
}
|
39
federationapi/storage/storage.go
Normal file
39
federationapi/storage/storage.go
Normal file
|
@ -0,0 +1,39 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//go:build !wasm
|
||||
// +build !wasm
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, cache)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return postgres.NewDatabase(dbProperties, cache)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
35
federationapi/storage/storage_wasm.go
Normal file
35
federationapi/storage/storage_wasm.go
Normal file
|
@ -0,0 +1,35 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
)
|
||||
|
||||
// NewDatabase opens a new database
|
||||
func NewDatabase(dbProperties *config.DatabaseOptions, cache caching.FederationCache) (Database, error) {
|
||||
switch {
|
||||
case dbProperties.ConnectionString.IsSQLite():
|
||||
return sqlite3.NewDatabase(dbProperties, cache)
|
||||
case dbProperties.ConnectionString.IsPostgres():
|
||||
return nil, fmt.Errorf("can't use Postgres implementation")
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected database type")
|
||||
}
|
||||
}
|
111
federationapi/storage/tables/interface.go
Normal file
111
federationapi/storage/tables/interface.go
Normal file
|
@ -0,0 +1,111 @@
|
|||
// Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package tables
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
type NotaryID int64
|
||||
|
||||
type FederationQueuePDUs interface {
|
||||
InsertQueuePDU(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||
DeleteQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueuePDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||
SelectQueuePDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
SelectQueuePDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueuePDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||
}
|
||||
|
||||
type FederationQueueEDUs interface {
|
||||
InsertQueueEDU(ctx context.Context, txn *sql.Tx, eduType string, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||
DeleteQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueueEDUs(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueueEDUReferenceJSONCount(ctx context.Context, txn *sql.Tx, jsonNID int64) (int64, error)
|
||||
SelectQueueEDUCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
SelectQueueEDUServerNames(ctx context.Context, txn *sql.Tx) ([]gomatrixserverlib.ServerName, error)
|
||||
}
|
||||
|
||||
type FederationQueueJSON interface {
|
||||
InsertQueueJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
|
||||
DeleteQueueJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
|
||||
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
|
||||
}
|
||||
|
||||
type FederationJoinedHosts interface {
|
||||
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
|
||||
DeleteJoinedHostsForRoom(ctx context.Context, txn *sql.Tx, roomID string) error
|
||||
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
||||
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
||||
}
|
||||
|
||||
type FederationBlacklist interface {
|
||||
InsertBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||
SelectBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
DeleteBlacklist(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationOutboundPeeks interface {
|
||||
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
SelectOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (outboundPeek *types.OutboundPeek, err error)
|
||||
SelectOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (outboundPeeks []types.OutboundPeek, err error)
|
||||
DeleteOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error)
|
||||
DeleteOutboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
}
|
||||
|
||||
type FederationInboundPeeks interface {
|
||||
InsertInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
RenewInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
SelectInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (inboundPeek *types.InboundPeek, err error)
|
||||
SelectInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (inboundPeeks []types.InboundPeek, err error)
|
||||
DeleteInboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string) (err error)
|
||||
DeleteInboundPeeks(ctx context.Context, txn *sql.Tx, roomID string) (err error)
|
||||
}
|
||||
|
||||
// FederationNotaryServerKeysJSON contains the byte-for-byte responses from servers which contain their keys and is signed by them.
|
||||
type FederationNotaryServerKeysJSON interface {
|
||||
// InsertJSONResponse inserts a new response JSON. Useless on its own, needs querying via FederationNotaryServerKeysMetadata
|
||||
// `validUntil` should be the value of `valid_until_ts` with the 7-day check applied from:
|
||||
// "Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
|
||||
// This is to avoid a situation where an attacker publishes a key which is valid for a significant amount of time
|
||||
// without a way for the homeserver owner to revoke it.""
|
||||
InsertJSONResponse(ctx context.Context, txn *sql.Tx, keyQueryResponseJSON gomatrixserverlib.ServerKeys, serverName gomatrixserverlib.ServerName, validUntil gomatrixserverlib.Timestamp) (NotaryID, error)
|
||||
}
|
||||
|
||||
// FederationNotaryServerKeysMetadata persists the metadata for FederationNotaryServerKeysJSON
|
||||
type FederationNotaryServerKeysMetadata interface {
|
||||
// UpsertKey updates or inserts a (server_name, key_id) tuple, pointing it via NotaryID at the the response which has the longest valid_until_ts
|
||||
// `newNotaryID` and `newValidUntil` should be the notary ID / valid_until which has this (server_name, key_id) tuple already, e.g one you just inserted.
|
||||
UpsertKey(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyID gomatrixserverlib.KeyID, newNotaryID NotaryID, newValidUntil gomatrixserverlib.Timestamp) (NotaryID, error)
|
||||
// SelectKeys returns the signed JSON objects which contain the given key IDs. This will be at most the length of `keyIDs` and at least 1 (assuming
|
||||
// the keys exist in the first place). If `keyIDs` is empty, the signed JSON object with the longest valid_until_ts will be returned.
|
||||
SelectKeys(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, keyIDs []gomatrixserverlib.KeyID) ([]gomatrixserverlib.ServerKeys, error)
|
||||
// DeleteOldJSONResponses removes all responses which are not referenced in FederationNotaryServerKeysMetadata
|
||||
DeleteOldJSONResponses(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationServerSigningKeys interface {
|
||||
BulkSelectServerKeys(ctx context.Context, txn *sql.Tx, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error)
|
||||
UpsertServerKeys(ctx context.Context, txn *sql.Tx, request gomatrixserverlib.PublicKeyLookupRequest, key gomatrixserverlib.PublicKeyLookupResult) error
|
||||
}
|
53
federationapi/types/types.go
Normal file
53
federationapi/types/types.go
Normal file
|
@ -0,0 +1,53 @@
|
|||
// Copyright 2017 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// A JoinedHost is a server that is joined to a matrix room.
|
||||
type JoinedHost struct {
|
||||
// The MemberEventID of a m.room.member join event.
|
||||
MemberEventID string
|
||||
// The domain part of the state key of the m.room.member join event
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
type ServerNames []gomatrixserverlib.ServerName
|
||||
|
||||
func (s ServerNames) Len() int { return len(s) }
|
||||
func (s ServerNames) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
func (s ServerNames) Less(i, j int) bool { return s[i] < s[j] }
|
||||
|
||||
// tracks peeks we're performing on another server over federation
|
||||
type OutboundPeek struct {
|
||||
PeekID string
|
||||
RoomID string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
CreationTimestamp int64
|
||||
RenewedTimestamp int64
|
||||
RenewalInterval int64
|
||||
}
|
||||
|
||||
// tracks peeks other servers are performing on us over federation
|
||||
type InboundPeek struct {
|
||||
PeekID string
|
||||
RoomID string
|
||||
ServerName gomatrixserverlib.ServerName
|
||||
CreationTimestamp int64
|
||||
RenewedTimestamp int64
|
||||
RenewalInterval int64
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue