mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Don't get blacklisted hosts when querying joined servers (#2880)
Otherwise we just waste time/CPU.
This commit is contained in:
parent
5c9aed6af9
commit
9b8bb55430
12 changed files with 56 additions and 32 deletions
|
@ -198,8 +198,9 @@ type PerformInviteResponse struct {
|
||||||
|
|
||||||
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
|
// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames
|
||||||
type QueryJoinedHostServerNamesInRoomRequest struct {
|
type QueryJoinedHostServerNamesInRoomRequest struct {
|
||||||
RoomID string `json:"room_id"`
|
RoomID string `json:"room_id"`
|
||||||
ExcludeSelf bool `json:"exclude_self"`
|
ExcludeSelf bool `json:"exclude_self"`
|
||||||
|
ExcludeBlacklisted bool `json:"exclude_blacklisted"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
|
// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames
|
||||||
|
|
|
@ -128,7 +128,7 @@ func (t *KeyChangeConsumer) onDeviceKeyMessage(m api.DeviceMessage) bool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// send this key change to all servers who share rooms with this user.
|
// send this key change to all servers who share rooms with this user.
|
||||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
logger.WithError(err).Error("failed to calculate joined hosts for rooms user is in")
|
||||||
|
@ -189,7 +189,7 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
// send this key change to all servers who share rooms with this user.
|
// send this key change to all servers who share rooms with this user.
|
||||||
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
destinations, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
sentry.CaptureException(err)
|
sentry.CaptureException(err)
|
||||||
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
|
logger.WithError(err).Error("fedsender key change consumer: failed to calculate joined hosts for rooms user is in")
|
||||||
|
|
|
@ -111,7 +111,7 @@ func (t *OutputPresenceConsumer) onMessage(ctx context.Context, msgs []*nats.Msg
|
||||||
}
|
}
|
||||||
|
|
||||||
// send this presence to all servers who share rooms with this user.
|
// send this presence to all servers who share rooms with this user.
|
||||||
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true)
|
joined, err := t.db.GetJoinedHostsForRooms(t.ctx, queryRes.RoomIDs, true, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).Error("failed to get joined hosts")
|
log.WithError(err).Error("failed to get joined hosts")
|
||||||
return true
|
return true
|
||||||
|
|
|
@ -16,7 +16,7 @@ func (f *FederationInternalAPI) QueryJoinedHostServerNamesInRoom(
|
||||||
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
request *api.QueryJoinedHostServerNamesInRoomRequest,
|
||||||
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
response *api.QueryJoinedHostServerNamesInRoomResponse,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf)
|
joinedHosts, err := f.db.GetJoinedHostsForRooms(ctx, []string{request.RoomID}, request.ExcludeSelf, request.ExcludeBlacklisted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -32,7 +32,7 @@ type Database interface {
|
||||||
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
// GetJoinedHostsForRooms returns the complete set of servers in the rooms given.
|
||||||
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error)
|
GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
|
||||||
|
|
||||||
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
StoreJSON(ctx context.Context, js string) (*shared.Receipt, error)
|
||||||
|
|
||||||
|
|
|
@ -66,14 +66,20 @@ const selectAllJoinedHostsSQL = "" +
|
||||||
const selectJoinedHostsForRoomsSQL = "" +
|
const selectJoinedHostsForRoomsSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)"
|
||||||
|
|
||||||
|
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
|
||||||
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id = ANY($1) AND NOT EXISTS (" +
|
||||||
|
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
|
||||||
|
");"
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsStmt *sql.Stmt
|
deleteJoinedHostsStmt *sql.Stmt
|
||||||
deleteJoinedHostsForRoomStmt *sql.Stmt
|
deleteJoinedHostsForRoomStmt *sql.Stmt
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
selectJoinedHostsForRoomsStmt *sql.Stmt
|
selectJoinedHostsForRoomsStmt *sql.Stmt
|
||||||
|
selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
|
@ -102,6 +108,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro
|
||||||
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
|
if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if s.selectJoinedHostsForRoomsExcludingBlacklistedStmt, err = s.db.Prepare(selectJoinedHostsForRoomsExcludingBlacklistedSQL); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -167,9 +176,13 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||||
ctx context.Context, roomIDs []string,
|
ctx context.Context, roomIDs []string, excludingBlacklisted bool,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
rows, err := s.selectJoinedHostsForRoomsStmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
stmt := s.selectJoinedHostsForRoomsStmt
|
||||||
|
if excludingBlacklisted {
|
||||||
|
stmt = s.selectJoinedHostsForRoomsExcludingBlacklistedStmt
|
||||||
|
}
|
||||||
|
rows, err := stmt.QueryContext(ctx, pq.StringArray(roomIDs))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -42,6 +42,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
|
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewDummyWriter()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
blacklist, err := NewPostgresBlacklistTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
joinedHosts, err := NewPostgresJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -58,10 +62,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
blacklist, err := NewPostgresBlacklistTable(d.db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
inboundPeeks, err := NewPostgresInboundPeeksTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -117,8 +117,8 @@ func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.S
|
||||||
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
return d.FederationJoinedHosts.SelectAllJoinedHosts(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf bool) ([]gomatrixserverlib.ServerName, error) {
|
func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludeSelf, excludeBlacklisted bool) ([]gomatrixserverlib.ServerName, error) {
|
||||||
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs)
|
servers, err := d.FederationJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs, excludeBlacklisted)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -66,6 +66,11 @@ const selectAllJoinedHostsSQL = "" +
|
||||||
const selectJoinedHostsForRoomsSQL = "" +
|
const selectJoinedHostsForRoomsSQL = "" +
|
||||||
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)"
|
||||||
|
|
||||||
|
const selectJoinedHostsForRoomsExcludingBlacklistedSQL = "" +
|
||||||
|
"SELECT DISTINCT server_name FROM federationsender_joined_hosts j WHERE room_id IN ($1) AND NOT EXISTS (" +
|
||||||
|
" SELECT server_name FROM federationsender_blacklist WHERE j.server_name = server_name" +
|
||||||
|
");"
|
||||||
|
|
||||||
type joinedHostsStatements struct {
|
type joinedHostsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertJoinedHostsStmt *sql.Stmt
|
insertJoinedHostsStmt *sql.Stmt
|
||||||
|
@ -74,6 +79,7 @@ type joinedHostsStatements struct {
|
||||||
selectJoinedHostsStmt *sql.Stmt
|
selectJoinedHostsStmt *sql.Stmt
|
||||||
selectAllJoinedHostsStmt *sql.Stmt
|
selectAllJoinedHostsStmt *sql.Stmt
|
||||||
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
|
// selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
|
// selectJoinedHostsForRoomsExcludingBlacklistedStmt *sql.Stmt - prepared at runtime due to variadic
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) {
|
||||||
|
@ -168,14 +174,17 @@ func (s *joinedHostsStatements) SelectAllJoinedHosts(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
func (s *joinedHostsStatements) SelectJoinedHostsForRooms(
|
||||||
ctx context.Context, roomIDs []string,
|
ctx context.Context, roomIDs []string, excludingBlacklisted bool,
|
||||||
) ([]gomatrixserverlib.ServerName, error) {
|
) ([]gomatrixserverlib.ServerName, error) {
|
||||||
iRoomIDs := make([]interface{}, len(roomIDs))
|
iRoomIDs := make([]interface{}, len(roomIDs))
|
||||||
for i := range roomIDs {
|
for i := range roomIDs {
|
||||||
iRoomIDs[i] = roomIDs[i]
|
iRoomIDs[i] = roomIDs[i]
|
||||||
}
|
}
|
||||||
|
query := selectJoinedHostsForRoomsSQL
|
||||||
sql := strings.Replace(selectJoinedHostsForRoomsSQL, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
|
if excludingBlacklisted {
|
||||||
|
query = selectJoinedHostsForRoomsExcludingBlacklistedSQL
|
||||||
|
}
|
||||||
|
sql := strings.Replace(query, "($1)", sqlutil.QueryVariadic(len(iRoomIDs)), 1)
|
||||||
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
|
rows, err := s.db.QueryContext(ctx, sql, iRoomIDs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -41,6 +41,10 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
|
if d.db, d.writer, err = base.DatabaseConnection(dbProperties, sqlutil.NewExclusiveWriter()); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
blacklist, err := NewSQLiteBlacklistTable(d.db)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
joinedHosts, err := NewSQLiteJoinedHostsTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -57,10 +61,6 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
blacklist, err := NewSQLiteBlacklistTable(d.db)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
outboundPeeks, err := NewSQLiteOutboundPeeksTable(d.db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -58,7 +58,7 @@ type FederationJoinedHosts interface {
|
||||||
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHostsWithTx(ctx context.Context, txn *sql.Tx, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
|
||||||
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
|
||||||
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error)
|
SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string, excludingBlacklisted bool) ([]gomatrixserverlib.ServerName, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type FederationBlacklist interface {
|
type FederationBlacklist interface {
|
||||||
|
|
|
@ -165,8 +165,9 @@ func (r *Inputer) processRoomEvent(
|
||||||
|
|
||||||
if missingAuth || missingPrev {
|
if missingAuth || missingPrev {
|
||||||
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
|
serverReq := &fedapi.QueryJoinedHostServerNamesInRoomRequest{
|
||||||
RoomID: event.RoomID(),
|
RoomID: event.RoomID(),
|
||||||
ExcludeSelf: true,
|
ExcludeSelf: true,
|
||||||
|
ExcludeBlacklisted: true,
|
||||||
}
|
}
|
||||||
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
if err = r.FSAPI.QueryJoinedHostServerNamesInRoom(ctx, serverReq, serverRes); err != nil {
|
||||||
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
return fmt.Errorf("r.FSAPI.QueryJoinedHostServerNamesInRoom: %w", err)
|
||||||
|
|
Loading…
Reference in a new issue