mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Initial Store & Forward Implementation (#2917)
This adds store & forward relays into dendrite for p2p. A few things have changed: - new relay api serves new http endpoints for s&f federation - updated outbound federation queueing which will attempt to forward using s&f if appropriate - database entries to track s&f relays for other nodes
This commit is contained in:
parent
48fa869fa3
commit
5b73592f5a
77 changed files with 7646 additions and 1373 deletions
|
@ -20,11 +20,12 @@ import (
|
|||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
)
|
||||
|
||||
type Database interface {
|
||||
P2PDatabase
|
||||
gomatrixserverlib.KeyDatabase
|
||||
|
||||
UpdateRoom(ctx context.Context, roomID string, addHosts []types.JoinedHost, removeHosts []string, purgeRoomFirst bool) (joinedHosts []types.JoinedHost, err error)
|
||||
|
@ -34,16 +35,16 @@ type Database interface {
|
|||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||
StoreJSON(ctx context.Context, js string) (*receipt.Receipt, error)
|
||||
|
||||
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
|
||||
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
|
||||
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
|
||||
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*receipt.Receipt]*gomatrixserverlib.EDU, err error)
|
||||
|
||||
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
|
||||
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
|
||||
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt) error
|
||||
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, dbReceipt *receipt.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
|
||||
|
||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
|
||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*receipt.Receipt) error
|
||||
|
||||
GetPendingPDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
GetPendingEDUServerNames(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||
|
@ -54,6 +55,18 @@ type Database interface {
|
|||
RemoveAllServersFromBlacklist() error
|
||||
IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
|
||||
// Adds the server to the list of assumed offline servers.
|
||||
// If the server already exists in the table, nothing happens and returns success.
|
||||
SetServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
|
||||
// Removes the server from the list of assumed offline servers.
|
||||
// If the server doesn't exist in the table, nothing happens and returns success.
|
||||
RemoveServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) error
|
||||
// Purges all entries from the assumed offline table.
|
||||
RemoveAllServersAssumedOffline(ctx context.Context) error
|
||||
// Gets whether the provided server is present in the table.
|
||||
// If it is present, returns true. If not, returns false.
|
||||
IsServerAssumedOffline(ctx context.Context, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
|
||||
AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error
|
||||
GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error)
|
||||
|
@ -74,3 +87,21 @@ type Database interface {
|
|||
|
||||
PurgeRoom(ctx context.Context, roomID string) error
|
||||
}
|
||||
|
||||
type P2PDatabase interface {
|
||||
// Stores the given list of servers as relay servers for the provided destination server.
|
||||
// Providing duplicates will only lead to a single entry and won't lead to an error.
|
||||
P2PAddRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
|
||||
|
||||
// Get the list of relay servers associated with the provided destination server.
|
||||
// If no entry exists in the table, an empty list is returned and does not result in an error.
|
||||
P2PGetRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
|
||||
|
||||
// Deletes any entries for the provided destination server that match the provided relayServers list.
|
||||
// If any of the provided servers don't match an entry, nothing happens and no error is returned.
|
||||
P2PRemoveRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
|
||||
|
||||
// Deletes all entries for the provided destination server.
|
||||
// If the destination server doesn't exist in the table, nothing happens and no error is returned.
|
||||
P2PRemoveAllRelayServersForServer(ctx context.Context, serverName gomatrixserverlib.ServerName) error
|
||||
}
|
||||
|
|
107
federationapi/storage/postgres/assumed_offline_table.go
Normal file
107
federationapi/storage/postgres/assumed_offline_table.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const assumedOfflineSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
|
||||
-- The assumed offline server name
|
||||
server_name TEXT PRIMARY KEY NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertAssumedOfflineSQL = "" +
|
||||
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectAssumedOfflineSQL = "" +
|
||||
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
|
||||
|
||||
const deleteAssumedOfflineSQL = "" +
|
||||
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
|
||||
|
||||
const deleteAllAssumedOfflineSQL = "" +
|
||||
"TRUNCATE federationsender_assumed_offline"
|
||||
|
||||
type assumedOfflineStatements struct {
|
||||
db *sql.DB
|
||||
insertAssumedOfflineStmt *sql.Stmt
|
||||
selectAssumedOfflineStmt *sql.Stmt
|
||||
deleteAssumedOfflineStmt *sql.Stmt
|
||||
deleteAllAssumedOfflineStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
|
||||
s = &assumedOfflineStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(assumedOfflineSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
|
||||
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
|
||||
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
|
||||
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) InsertAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) SelectAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is assume offline, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) DeleteAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
137
federationapi/storage/postgres/relay_servers_table.go
Normal file
137
federationapi/storage/postgres/relay_servers_table.go
Normal file
|
@ -0,0 +1,137 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const relayServersSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
|
||||
-- The destination server name
|
||||
server_name TEXT NOT NULL,
|
||||
-- The relay server name for a given destination
|
||||
relay_server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name, relay_server_name)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
|
||||
ON federationsender_relay_servers (server_name);
|
||||
`
|
||||
|
||||
const insertRelayServersSQL = "" +
|
||||
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectRelayServersSQL = "" +
|
||||
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
const deleteRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name = ANY($2)"
|
||||
|
||||
const deleteAllRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
type relayServersStatements struct {
|
||||
db *sql.DB
|
||||
insertRelayServersStmt *sql.Stmt
|
||||
selectRelayServersStmt *sql.Stmt
|
||||
deleteRelayServersStmt *sql.Stmt
|
||||
deleteAllRelayServersStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
|
||||
s = &relayServersStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(relayServersSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertRelayServersStmt, insertRelayServersSQL},
|
||||
{&s.selectRelayServersStmt, selectRelayServersSQL},
|
||||
{&s.deleteRelayServersStmt, deleteRelayServersSQL},
|
||||
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) InsertRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
for _, relayServer := range relayServers {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) SelectRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var relayServer string
|
||||
if err = rows.Scan(&relayServer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(relayServer))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteRelayServersStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName, pq.Array(relayServers))
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteAllRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -62,6 +62,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assumedOffline, err := NewPostgresAssumedOfflineTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relayServers, err := NewPostgresRelayServersTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -104,6 +112,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
NotaryServerKeysJSON: notaryJSON,
|
||||
|
|
42
federationapi/storage/shared/receipt/receipt.go
Normal file
42
federationapi/storage/shared/receipt/receipt.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
// Copyright 2023 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// A Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
|
||||
// We don't actually export the NIDs but we need the caller to be able
|
||||
// to pass them back so that we can clean up if the transaction sends
|
||||
// successfully.
|
||||
|
||||
package receipt
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Receipt is a wrapper type used to represent a nid that corresponds to a unique row entry
|
||||
// in some database table.
|
||||
// The internal nid value cannot be modified after a Receipt has been created.
|
||||
// This guarantees a receipt will always refer to the same table entry that it was created
|
||||
// to represent.
|
||||
type Receipt struct {
|
||||
nid int64
|
||||
}
|
||||
|
||||
func NewReceipt(nid int64) Receipt {
|
||||
return Receipt{nid: nid}
|
||||
}
|
||||
|
||||
func (r *Receipt) GetNID() int64 {
|
||||
return r.nid
|
||||
}
|
||||
|
||||
func (r *Receipt) String() string {
|
||||
return fmt.Sprintf("%d", r.nid)
|
||||
}
|
|
@ -20,6 +20,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/federationapi/types"
|
||||
"github.com/matrix-org/dendrite/internal/caching"
|
||||
|
@ -37,6 +38,8 @@ type Database struct {
|
|||
FederationQueueJSON tables.FederationQueueJSON
|
||||
FederationJoinedHosts tables.FederationJoinedHosts
|
||||
FederationBlacklist tables.FederationBlacklist
|
||||
FederationAssumedOffline tables.FederationAssumedOffline
|
||||
FederationRelayServers tables.FederationRelayServers
|
||||
FederationOutboundPeeks tables.FederationOutboundPeeks
|
||||
FederationInboundPeeks tables.FederationInboundPeeks
|
||||
NotaryServerKeysJSON tables.FederationNotaryServerKeysJSON
|
||||
|
@ -44,22 +47,6 @@ type Database struct {
|
|||
ServerSigningKeys tables.FederationServerSigningKeys
|
||||
}
|
||||
|
||||
// An Receipt contains the NIDs of a call to GetNextTransactionPDUs/EDUs.
|
||||
// We don't actually export the NIDs but we need the caller to be able
|
||||
// to pass them back so that we can clean up if the transaction sends
|
||||
// successfully.
|
||||
type Receipt struct {
|
||||
nid int64
|
||||
}
|
||||
|
||||
func NewReceipt(nid int64) Receipt {
|
||||
return Receipt{nid: nid}
|
||||
}
|
||||
|
||||
func (r *Receipt) String() string {
|
||||
return fmt.Sprintf("%d", r.nid)
|
||||
}
|
||||
|
||||
// UpdateRoom updates the joined hosts for a room and returns what the joined
|
||||
// hosts were before the update, or nil if this was a duplicate message.
|
||||
// This is called when we receive a message from kafka, so we pass in
|
||||
|
@ -113,11 +100,18 @@ func (d *Database) GetJoinedHosts(
|
|||
// GetAllJoinedHosts returns the currently joined hosts for
|
||||
// all rooms known to the federation sender.
|
||||
// Returns an error if something goes wrong.
|
||||
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
|
||||
func (d *Database) GetAllJoinedHosts(
|
||||
ctx context.Context,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
||||
}
|
||||
|
||||
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
|
||||
func (d *Database) GetJoinedHostsForRooms(
|
||||
ctx context.Context,
|
||||
roomIDs []string,
|
||||
excludeSelf,
|
||||
excludeBlacklisted bool,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -139,7 +133,7 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string,
|
|||
// metadata entries.
|
||||
func (d *Database) StoreJSON(
|
||||
ctx context.Context, js string,
|
||||
) (*Receipt, error) {
|
||||
) (*receipt.Receipt, error) {
|
||||
var nid int64
|
||||
var err error
|
||||
_ = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
|
@ -149,18 +143,21 @@ func (d *Database) StoreJSON(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("d.insertQueueJSON: %w", err)
|
||||
}
|
||||
return &Receipt{
|
||||
nid: nid,
|
||||
}, nil
|
||||
newReceipt := receipt.NewReceipt(nid)
|
||||
return &newReceipt, nil
|
||||
}
|
||||
|
||||
func (d *Database) AddServerToBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||
func (d *Database) AddServerToBlacklist(
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.InsertBlacklist(context.TODO(), txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveServerFromBlacklist(serverName gomatrixserverlib.ServerName) error {
|
||||
func (d *Database) RemoveServerFromBlacklist(
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationBlacklist.DeleteBlacklist(context.TODO(), txn, serverName)
|
||||
})
|
||||
|
@ -172,51 +169,166 @@ func (d *Database) RemoveAllServersFromBlacklist() error {
|
|||
})
|
||||
}
|
||||
|
||||
func (d *Database) IsServerBlacklisted(serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
func (d *Database) IsServerBlacklisted(
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
return d.FederationBlacklist.SelectBlacklist(context.TODO(), nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) AddOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
func (d *Database) SetServerAssumedOffline(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationAssumedOffline.InsertAssumedOffline(ctx, txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveServerAssumedOffline(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationAssumedOffline.DeleteAssumedOffline(ctx, txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RemoveAllServersAssumedOffline(
|
||||
ctx context.Context,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationAssumedOffline.DeleteAllAssumedOffline(ctx, txn)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) IsServerAssumedOffline(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
return d.FederationAssumedOffline.SelectAssumedOffline(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) P2PAddRelayServersForServer(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationRelayServers.InsertRelayServers(ctx, txn, serverName, relayServers)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) P2PGetRelayServersForServer(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
return d.FederationRelayServers.SelectRelayServers(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) P2PRemoveRelayServersForServer(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationRelayServers.DeleteRelayServers(ctx, txn, serverName, relayServers)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) P2PRemoveAllRelayServersForServer(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationRelayServers.DeleteAllRelayServers(ctx, txn, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) AddOutboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID string,
|
||||
peekID string,
|
||||
renewalInterval int64,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationOutboundPeeks.InsertOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RenewOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
func (d *Database) RenewOutboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID string,
|
||||
peekID string,
|
||||
renewalInterval int64,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationOutboundPeeks.RenewOutboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetOutboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.OutboundPeek, error) {
|
||||
func (d *Database) GetOutboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID,
|
||||
peekID string,
|
||||
) (*types.OutboundPeek, error) {
|
||||
return d.FederationOutboundPeeks.SelectOutboundPeek(ctx, nil, serverName, roomID, peekID)
|
||||
}
|
||||
|
||||
func (d *Database) GetOutboundPeeks(ctx context.Context, roomID string) ([]types.OutboundPeek, error) {
|
||||
func (d *Database) GetOutboundPeeks(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
) ([]types.OutboundPeek, error) {
|
||||
return d.FederationOutboundPeeks.SelectOutboundPeeks(ctx, nil, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) AddInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
func (d *Database) AddInboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID string,
|
||||
peekID string,
|
||||
renewalInterval int64,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationInboundPeeks.InsertInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) RenewInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) error {
|
||||
func (d *Database) RenewInboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID string,
|
||||
peekID string,
|
||||
renewalInterval int64,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.FederationInboundPeeks.RenewInboundPeek(ctx, txn, serverName, roomID, peekID, renewalInterval)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) GetInboundPeek(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID, peekID string) (*types.InboundPeek, error) {
|
||||
func (d *Database) GetInboundPeek(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
roomID string,
|
||||
peekID string,
|
||||
) (*types.InboundPeek, error) {
|
||||
return d.FederationInboundPeeks.SelectInboundPeek(ctx, nil, serverName, roomID, peekID)
|
||||
}
|
||||
|
||||
func (d *Database) GetInboundPeeks(ctx context.Context, roomID string) ([]types.InboundPeek, error) {
|
||||
func (d *Database) GetInboundPeeks(
|
||||
ctx context.Context,
|
||||
roomID string,
|
||||
) ([]types.InboundPeek, error) {
|
||||
return d.FederationInboundPeeks.SelectInboundPeeks(ctx, nil, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserverlib.ServerName, serverKeys gomatrixserverlib.ServerKeys) error {
|
||||
func (d *Database) UpdateNotaryKeys(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
serverKeys gomatrixserverlib.ServerKeys,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
validUntil := serverKeys.ValidUntilTS
|
||||
// Servers MUST use the lesser of this field and 7 days into the future when determining if a key is valid.
|
||||
|
@ -251,7 +363,9 @@ func (d *Database) UpdateNotaryKeys(ctx context.Context, serverName gomatrixserv
|
|||
}
|
||||
|
||||
func (d *Database) GetNotaryKeys(
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName, optKeyIDs []gomatrixserverlib.KeyID,
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
optKeyIDs []gomatrixserverlib.KeyID,
|
||||
) (sks []gomatrixserverlib.ServerKeys, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
sks, err = d.NotaryServerKeysMetadata.SelectKeys(ctx, txn, serverName, optKeyIDs)
|
||||
|
|
|
@ -22,6 +22,7 @@ import (
|
|||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -41,7 +42,7 @@ var defaultExpireEDUTypes = map[string]time.Duration{
|
|||
func (d *Database) AssociateEDUWithDestinations(
|
||||
ctx context.Context,
|
||||
destinations map[gomatrixserverlib.ServerName]struct{},
|
||||
receipt *Receipt,
|
||||
dbReceipt *receipt.Receipt,
|
||||
eduType string,
|
||||
expireEDUTypes map[string]time.Duration,
|
||||
) error {
|
||||
|
@ -62,12 +63,12 @@ func (d *Database) AssociateEDUWithDestinations(
|
|||
var err error
|
||||
for destination := range destinations {
|
||||
err = d.FederationQueueEDUs.InsertQueueEDU(
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
eduType, // EDU type for coalescing
|
||||
destination, // destination server name
|
||||
receipt.nid, // NID from the federationapi_queue_json table
|
||||
expiresAt, // The timestamp this EDU will expire
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
eduType, // EDU type for coalescing
|
||||
destination, // destination server name
|
||||
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
|
||||
expiresAt, // The timestamp this EDU will expire
|
||||
)
|
||||
}
|
||||
return err
|
||||
|
@ -81,10 +82,10 @@ func (d *Database) GetPendingEDUs(
|
|||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) (
|
||||
edus map[*Receipt]*gomatrixserverlib.EDU,
|
||||
edus map[*receipt.Receipt]*gomatrixserverlib.EDU,
|
||||
err error,
|
||||
) {
|
||||
edus = make(map[*Receipt]*gomatrixserverlib.EDU)
|
||||
edus = make(map[*receipt.Receipt]*gomatrixserverlib.EDU)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nids, err := d.FederationQueueEDUs.SelectQueueEDUs(ctx, txn, serverName, limit)
|
||||
if err != nil {
|
||||
|
@ -94,7 +95,8 @@ func (d *Database) GetPendingEDUs(
|
|||
retrieve := make([]int64, 0, len(nids))
|
||||
for _, nid := range nids {
|
||||
if edu, ok := d.Cache.GetFederationQueuedEDU(nid); ok {
|
||||
edus[&Receipt{nid}] = edu
|
||||
newReceipt := receipt.NewReceipt(nid)
|
||||
edus[&newReceipt] = edu
|
||||
} else {
|
||||
retrieve = append(retrieve, nid)
|
||||
}
|
||||
|
@ -110,7 +112,8 @@ func (d *Database) GetPendingEDUs(
|
|||
if err := json.Unmarshal(blob, &event); err != nil {
|
||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
edus[&Receipt{nid}] = &event
|
||||
newReceipt := receipt.NewReceipt(nid)
|
||||
edus[&newReceipt] = &event
|
||||
d.Cache.StoreFederationQueuedEDU(nid, &event)
|
||||
}
|
||||
|
||||
|
@ -124,7 +127,7 @@ func (d *Database) GetPendingEDUs(
|
|||
func (d *Database) CleanEDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipts []*Receipt,
|
||||
receipts []*receipt.Receipt,
|
||||
) error {
|
||||
if len(receipts) == 0 {
|
||||
return errors.New("expected receipt")
|
||||
|
@ -132,7 +135,7 @@ func (d *Database) CleanEDUs(
|
|||
|
||||
nids := make([]int64, len(receipts))
|
||||
for i := range receipts {
|
||||
nids[i] = receipts[i].nid
|
||||
nids[i] = receipts[i].GetNID()
|
||||
}
|
||||
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/shared/receipt"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -30,17 +31,17 @@ import (
|
|||
func (d *Database) AssociatePDUWithDestinations(
|
||||
ctx context.Context,
|
||||
destinations map[gomatrixserverlib.ServerName]struct{},
|
||||
receipt *Receipt,
|
||||
dbReceipt *receipt.Receipt,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
for destination := range destinations {
|
||||
err = d.FederationQueuePDUs.InsertQueuePDU(
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
"", // transaction ID
|
||||
destination, // destination server name
|
||||
receipt.nid, // NID from the federationapi_queue_json table
|
||||
ctx, // context
|
||||
txn, // SQL transaction
|
||||
"", // transaction ID
|
||||
destination, // destination server name
|
||||
dbReceipt.GetNID(), // NID from the federationapi_queue_json table
|
||||
)
|
||||
}
|
||||
return err
|
||||
|
@ -54,7 +55,7 @@ func (d *Database) GetPendingPDUs(
|
|||
serverName gomatrixserverlib.ServerName,
|
||||
limit int,
|
||||
) (
|
||||
events map[*Receipt]*gomatrixserverlib.HeaderedEvent,
|
||||
events map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent,
|
||||
err error,
|
||||
) {
|
||||
// Strictly speaking this doesn't need to be using the writer
|
||||
|
@ -62,7 +63,7 @@ func (d *Database) GetPendingPDUs(
|
|||
// a guarantee of transactional isolation, it's actually useful
|
||||
// to know in SQLite mode that nothing else is trying to modify
|
||||
// the database.
|
||||
events = make(map[*Receipt]*gomatrixserverlib.HeaderedEvent)
|
||||
events = make(map[*receipt.Receipt]*gomatrixserverlib.HeaderedEvent)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
nids, err := d.FederationQueuePDUs.SelectQueuePDUs(ctx, txn, serverName, limit)
|
||||
if err != nil {
|
||||
|
@ -72,7 +73,8 @@ func (d *Database) GetPendingPDUs(
|
|||
retrieve := make([]int64, 0, len(nids))
|
||||
for _, nid := range nids {
|
||||
if event, ok := d.Cache.GetFederationQueuedPDU(nid); ok {
|
||||
events[&Receipt{nid}] = event
|
||||
newReceipt := receipt.NewReceipt(nid)
|
||||
events[&newReceipt] = event
|
||||
} else {
|
||||
retrieve = append(retrieve, nid)
|
||||
}
|
||||
|
@ -88,7 +90,8 @@ func (d *Database) GetPendingPDUs(
|
|||
if err := json.Unmarshal(blob, &event); err != nil {
|
||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
events[&Receipt{nid}] = &event
|
||||
newReceipt := receipt.NewReceipt(nid)
|
||||
events[&newReceipt] = &event
|
||||
d.Cache.StoreFederationQueuedPDU(nid, &event)
|
||||
}
|
||||
|
||||
|
@ -103,7 +106,7 @@ func (d *Database) GetPendingPDUs(
|
|||
func (d *Database) CleanPDUs(
|
||||
ctx context.Context,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
receipts []*Receipt,
|
||||
receipts []*receipt.Receipt,
|
||||
) error {
|
||||
if len(receipts) == 0 {
|
||||
return errors.New("expected receipt")
|
||||
|
@ -111,7 +114,7 @@ func (d *Database) CleanPDUs(
|
|||
|
||||
nids := make([]int64, len(receipts))
|
||||
for i := range receipts {
|
||||
nids[i] = receipts[i].nid
|
||||
nids[i] = receipts[i].GetNID()
|
||||
}
|
||||
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
|
|
107
federationapi/storage/sqlite3/assumed_offline_table.go
Normal file
107
federationapi/storage/sqlite3/assumed_offline_table.go
Normal file
|
@ -0,0 +1,107 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const assumedOfflineSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_assumed_offline(
|
||||
-- The assumed offline server name
|
||||
server_name TEXT PRIMARY KEY NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertAssumedOfflineSQL = "" +
|
||||
"INSERT INTO federationsender_assumed_offline (server_name) VALUES ($1)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectAssumedOfflineSQL = "" +
|
||||
"SELECT server_name FROM federationsender_assumed_offline WHERE server_name = $1"
|
||||
|
||||
const deleteAssumedOfflineSQL = "" +
|
||||
"DELETE FROM federationsender_assumed_offline WHERE server_name = $1"
|
||||
|
||||
const deleteAllAssumedOfflineSQL = "" +
|
||||
"DELETE FROM federationsender_assumed_offline"
|
||||
|
||||
type assumedOfflineStatements struct {
|
||||
db *sql.DB
|
||||
insertAssumedOfflineStmt *sql.Stmt
|
||||
selectAssumedOfflineStmt *sql.Stmt
|
||||
deleteAssumedOfflineStmt *sql.Stmt
|
||||
deleteAllAssumedOfflineStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteAssumedOfflineTable(db *sql.DB) (s *assumedOfflineStatements, err error) {
|
||||
s = &assumedOfflineStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(assumedOfflineSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertAssumedOfflineStmt, insertAssumedOfflineSQL},
|
||||
{&s.selectAssumedOfflineStmt, selectAssumedOfflineSQL},
|
||||
{&s.deleteAssumedOfflineStmt, deleteAssumedOfflineSQL},
|
||||
{&s.deleteAllAssumedOfflineStmt, deleteAllAssumedOfflineSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) InsertAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) SelectAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (bool, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectAssumedOfflineStmt)
|
||||
res, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer res.Close() // nolint:errcheck
|
||||
// The query will return the server name if the server is assume offline, and
|
||||
// will return no rows if not. By calling Next, we find out if a row was
|
||||
// returned or not - we don't care about the value itself.
|
||||
return res.Next(), nil
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) DeleteAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *assumedOfflineStatements) DeleteAllAssumedOffline(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllAssumedOfflineStmt)
|
||||
_, err := stmt.ExecContext(ctx)
|
||||
return err
|
||||
}
|
148
federationapi/storage/sqlite3/relay_servers_table.go
Normal file
148
federationapi/storage/sqlite3/relay_servers_table.go
Normal file
|
@ -0,0 +1,148 @@
|
|||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const relayServersSchema = `
|
||||
CREATE TABLE IF NOT EXISTS federationsender_relay_servers (
|
||||
-- The destination server name
|
||||
server_name TEXT NOT NULL,
|
||||
-- The relay server name for a given destination
|
||||
relay_server_name TEXT NOT NULL,
|
||||
UNIQUE (server_name, relay_server_name)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS federationsender_relay_servers_server_name_idx
|
||||
ON federationsender_relay_servers (server_name);
|
||||
`
|
||||
|
||||
const insertRelayServersSQL = "" +
|
||||
"INSERT INTO federationsender_relay_servers (server_name, relay_server_name) VALUES ($1, $2)" +
|
||||
" ON CONFLICT DO NOTHING"
|
||||
|
||||
const selectRelayServersSQL = "" +
|
||||
"SELECT relay_server_name FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
const deleteRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1 AND relay_server_name IN ($2)"
|
||||
|
||||
const deleteAllRelayServersSQL = "" +
|
||||
"DELETE FROM federationsender_relay_servers WHERE server_name = $1"
|
||||
|
||||
type relayServersStatements struct {
|
||||
db *sql.DB
|
||||
insertRelayServersStmt *sql.Stmt
|
||||
selectRelayServersStmt *sql.Stmt
|
||||
// deleteRelayServersStmt *sql.Stmt - prepared at runtime due to variadic
|
||||
deleteAllRelayServersStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSQLiteRelayServersTable(db *sql.DB) (s *relayServersStatements, err error) {
|
||||
s = &relayServersStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err = db.Exec(relayServersSchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertRelayServersStmt, insertRelayServersSQL},
|
||||
{&s.selectRelayServersStmt, selectRelayServersSQL},
|
||||
{&s.deleteAllRelayServersStmt, deleteAllRelayServersSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) InsertRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
for _, relayServer := range relayServers {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName, relayServer); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) SelectRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) ([]gomatrixserverlib.ServerName, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectRelayServersStmt)
|
||||
rows, err := stmt.QueryContext(ctx, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "SelectRelayServers: rows.close() failed")
|
||||
|
||||
var result []gomatrixserverlib.ServerName
|
||||
for rows.Next() {
|
||||
var relayServer string
|
||||
if err = rows.Scan(&relayServer); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
result = append(result, gomatrixserverlib.ServerName(relayServer))
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
relayServers []gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
deleteSQL := strings.Replace(deleteRelayServersSQL, "($2)", sqlutil.QueryVariadicOffset(len(relayServers), 1), 1)
|
||||
deleteStmt, err := s.db.Prepare(deleteSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmt := sqlutil.TxStmt(txn, deleteStmt)
|
||||
params := make([]interface{}, len(relayServers)+1)
|
||||
params[0] = serverName
|
||||
for i, v := range relayServers {
|
||||
params[i+1] = v
|
||||
}
|
||||
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *relayServersStatements) DeleteAllRelayServers(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteAllRelayServersStmt)
|
||||
if _, err := stmt.ExecContext(ctx, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,5 +1,4 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
// Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
|
@ -61,6 +60,14 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
assumedOffline, err := NewSQLiteAssumedOfflineTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
relayServers, err := NewSQLiteRelayServersTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -103,6 +110,8 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
FederationQueueEDUs: queueEDUs,
|
||||
FederationQueueJSON: queueJSON,
|
||||
FederationBlacklist: blacklist,
|
||||
FederationAssumedOffline: assumedOffline,
|
||||
FederationRelayServers: relayServers,
|
||||
FederationOutboundPeeks: outboundPeeks,
|
||||
FederationInboundPeeks: inboundPeeks,
|
||||
NotaryServerKeysJSON: notaryKeys,
|
||||
|
|
|
@ -6,14 +6,13 @@ import (
|
|||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/dendrite/test/testrig"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/matrix-org/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||
|
@ -246,3 +245,99 @@ func TestInboundPeeking(t *testing.T) {
|
|||
assert.ElementsMatch(t, gotPeekIDs, peekIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func TestServersAssumedOffline(t *testing.T) {
|
||||
server1 := gomatrixserverlib.ServerName("server1")
|
||||
server2 := gomatrixserverlib.ServerName("server2")
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
// Set server1 & server2 as assumed offline.
|
||||
err := db.SetServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
err = db.SetServerAssumedOffline(context.Background(), server2)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Ensure both servers are assumed offline.
|
||||
isOffline, err := db.IsServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, isOffline)
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, isOffline)
|
||||
|
||||
// Set server1 as not assumed offline.
|
||||
err = db.RemoveServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Ensure both servers have correct state.
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, isOffline)
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, isOffline)
|
||||
|
||||
// Re-set server1 as assumed offline.
|
||||
err = db.SetServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Ensure server1 is assumed offline.
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, isOffline)
|
||||
|
||||
err = db.RemoveAllServersAssumedOffline(context.Background())
|
||||
assert.Nil(t, err)
|
||||
|
||||
// Ensure both servers have correct state.
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server1)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, isOffline)
|
||||
isOffline, err = db.IsServerAssumedOffline(context.Background(), server2)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, isOffline)
|
||||
})
|
||||
}
|
||||
|
||||
func TestRelayServersStored(t *testing.T) {
|
||||
server := gomatrixserverlib.ServerName("server")
|
||||
relayServer1 := gomatrixserverlib.ServerName("relayserver1")
|
||||
relayServer2 := gomatrixserverlib.ServerName("relayserver2")
|
||||
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, closeDB := mustCreateFederationDatabase(t, dbType)
|
||||
defer closeDB()
|
||||
|
||||
err := db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
|
||||
assert.Nil(t, err)
|
||||
|
||||
relayServers, err := db.P2PGetRelayServersForServer(context.Background(), server)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, relayServer1, relayServers[0])
|
||||
|
||||
err = db.P2PRemoveRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1})
|
||||
assert.Nil(t, err)
|
||||
|
||||
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
|
||||
assert.Nil(t, err)
|
||||
assert.Zero(t, len(relayServers))
|
||||
|
||||
err = db.P2PAddRelayServersForServer(context.Background(), server, []gomatrixserverlib.ServerName{relayServer1, relayServer2})
|
||||
assert.Nil(t, err)
|
||||
|
||||
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, relayServer1, relayServers[0])
|
||||
assert.Equal(t, relayServer2, relayServers[1])
|
||||
|
||||
err = db.P2PRemoveAllRelayServersForServer(context.Background(), server)
|
||||
assert.Nil(t, err)
|
||||
|
||||
relayServers, err = db.P2PGetRelayServersForServer(context.Background(), server)
|
||||
assert.Nil(t, err)
|
||||
assert.Zero(t, len(relayServers))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -49,6 +49,19 @@ type FederationQueueJSON interface {
|
|||
SelectQueueJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
|
||||
}
|
||||
|
||||
type FederationQueueTransactions interface {
|
||||
InsertQueueTransaction(ctx context.Context, txn *sql.Tx, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nid int64) error
|
||||
DeleteQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, jsonNIDs []int64) error
|
||||
SelectQueueTransactions(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, limit int) ([]int64, error)
|
||||
SelectQueueTransactionCount(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (int64, error)
|
||||
}
|
||||
|
||||
type FederationTransactionJSON interface {
|
||||
InsertTransactionJSON(ctx context.Context, txn *sql.Tx, json string) (int64, error)
|
||||
DeleteTransactionJSON(ctx context.Context, txn *sql.Tx, nids []int64) error
|
||||
SelectTransactionJSON(ctx context.Context, txn *sql.Tx, jsonNIDs []int64) (map[int64][]byte, error)
|
||||
}
|
||||
|
||||
type FederationJoinedHosts interface {
|
||||
InsertJoinedHosts(ctx context.Context, txn *sql.Tx, roomID, eventID string, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteJoinedHosts(ctx context.Context, txn *sql.Tx, eventIDs []string) error
|
||||
|
@ -66,6 +79,20 @@ type FederationBlacklist interface {
|
|||
DeleteAllBlacklist(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationAssumedOffline interface {
|
||||
InsertAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||
SelectAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||
DeleteAssumedOffline(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||
DeleteAllAssumedOffline(ctx context.Context, txn *sql.Tx) error
|
||||
}
|
||||
|
||||
type FederationRelayServers interface {
|
||||
InsertRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
|
||||
SelectRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) ([]gomatrixserverlib.ServerName, error)
|
||||
DeleteRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, relayServers []gomatrixserverlib.ServerName) error
|
||||
DeleteAllRelayServers(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName) error
|
||||
}
|
||||
|
||||
type FederationOutboundPeeks interface {
|
||||
InsertOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
RenewOutboundPeek(ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName, roomID, peekID string, renewalInterval int64) (err error)
|
||||
|
|
224
federationapi/storage/tables/relay_servers_table_test.go
Normal file
224
federationapi/storage/tables/relay_servers_table_test.go
Normal file
|
@ -0,0 +1,224 @@
|
|||
package tables_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/postgres"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/sqlite3"
|
||||
"github.com/matrix-org/dendrite/federationapi/storage/tables"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/setup/config"
|
||||
"github.com/matrix-org/dendrite/test"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
server1 = "server1"
|
||||
server2 = "server2"
|
||||
server3 = "server3"
|
||||
server4 = "server4"
|
||||
)
|
||||
|
||||
type RelayServersDatabase struct {
|
||||
DB *sql.DB
|
||||
Writer sqlutil.Writer
|
||||
Table tables.FederationRelayServers
|
||||
}
|
||||
|
||||
func mustCreateRelayServersTable(
|
||||
t *testing.T,
|
||||
dbType test.DBType,
|
||||
) (database RelayServersDatabase, close func()) {
|
||||
t.Helper()
|
||||
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||
ConnectionString: config.DataSource(connStr),
|
||||
}, sqlutil.NewExclusiveWriter())
|
||||
assert.NoError(t, err)
|
||||
var tab tables.FederationRelayServers
|
||||
switch dbType {
|
||||
case test.DBTypePostgres:
|
||||
tab, err = postgres.NewPostgresRelayServersTable(db)
|
||||
assert.NoError(t, err)
|
||||
case test.DBTypeSQLite:
|
||||
tab, err = sqlite3.NewSQLiteRelayServersTable(db)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
database = RelayServersDatabase{
|
||||
DB: db,
|
||||
Writer: sqlutil.NewDummyWriter(),
|
||||
Table: tab,
|
||||
}
|
||||
return database, close
|
||||
}
|
||||
|
||||
func Equal(a, b []gomatrixserverlib.ServerName) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i, v := range a {
|
||||
if v != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func TestShouldInsertRelayServers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
|
||||
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldInsertRelayServersWithDuplicates(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
insertRelayServers := []gomatrixserverlib.ServerName{server2, server2, server2, server3, server2}
|
||||
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
|
||||
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
// Insert the same list again, this shouldn't fail and should have no effect.
|
||||
err = db.Table.InsertRelayServers(ctx, nil, server1, insertRelayServers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldGetRelayServersUnknownDestination(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
|
||||
// Query relay servers for a destination that doesn't exist in the table.
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
|
||||
if !Equal(relayServers, []gomatrixserverlib.ServerName{}) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", []gomatrixserverlib.ServerName{}, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldDeleteCorrectRelayServers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
relayServers1 := []gomatrixserverlib.ServerName{server2, server3}
|
||||
relayServers2 := []gomatrixserverlib.ServerName{server1, server3, server4}
|
||||
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, relayServers1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
err = db.Table.InsertRelayServers(ctx, nil, server2, relayServers2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
err = db.Table.DeleteRelayServers(ctx, nil, server1, []gomatrixserverlib.ServerName{server2})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
|
||||
}
|
||||
err = db.Table.DeleteRelayServers(ctx, nil, server2, []gomatrixserverlib.ServerName{server1, server4})
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting relay servers for %s: %s", server2, err.Error())
|
||||
}
|
||||
|
||||
expectedRelayServers := []gomatrixserverlib.ServerName{server3}
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
}
|
||||
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestShouldDeleteAllRelayServers(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||
db, close := mustCreateRelayServersTable(t, dbType)
|
||||
defer close()
|
||||
expectedRelayServers := []gomatrixserverlib.ServerName{server2, server3}
|
||||
|
||||
err := db.Table.InsertRelayServers(ctx, nil, server1, expectedRelayServers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
err = db.Table.InsertRelayServers(ctx, nil, server2, expectedRelayServers)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed inserting transaction: %s", err.Error())
|
||||
}
|
||||
|
||||
err = db.Table.DeleteAllRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed deleting relay servers for %s: %s", server1, err.Error())
|
||||
}
|
||||
|
||||
expectedRelayServers1 := []gomatrixserverlib.ServerName{}
|
||||
relayServers, err := db.Table.SelectRelayServers(ctx, nil, server1)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers1) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers1, relayServers)
|
||||
}
|
||||
relayServers, err = db.Table.SelectRelayServers(ctx, nil, server2)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed retrieving relay servers for %s: %s", relayServers, err.Error())
|
||||
}
|
||||
if !Equal(relayServers, expectedRelayServers) {
|
||||
t.Fatalf("Expected: %v \nActual: %v", expectedRelayServers, relayServers)
|
||||
}
|
||||
})
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue