Add support for broadcasting wake-up EDUs to known hosts

This commit is contained in:
Neil Alexander 2020-07-16 13:42:22 +01:00
parent 8a5c2020b3
commit b0a3ee6c5c
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
10 changed files with 142 additions and 7 deletions

View file

@ -42,6 +42,12 @@ type FederationSenderInternalAPI interface {
request *PerformServersAliveRequest, request *PerformServersAliveRequest,
response *PerformServersAliveResponse, response *PerformServersAliveResponse,
) error ) error
// Broadcasts an EDU to all servers in rooms we are joined to.
PerformBroadcastEDU(
ctx context.Context,
request *PerformBroadcastEDURequest,
response *PerformBroadcastEDUResponse,
) error
} }
type PerformDirectoryLookupRequest struct { type PerformDirectoryLookupRequest struct {
@ -91,3 +97,9 @@ type QueryJoinedHostServerNamesInRoomRequest struct {
type QueryJoinedHostServerNamesInRoomResponse struct { type QueryJoinedHostServerNamesInRoomResponse struct {
ServerNames []gomatrixserverlib.ServerName `json:"server_names"` ServerNames []gomatrixserverlib.ServerName `json:"server_names"`
} }
type PerformBroadcastEDURequest struct {
}
type PerformBroadcastEDUResponse struct {
}

View file

@ -308,3 +308,25 @@ func (r *FederationSenderInternalAPI) PerformServersAlive(
return nil return nil
} }
// PerformServersAlive implements api.FederationSenderInternalAPI
func (r *FederationSenderInternalAPI) PerformBroadcastEDU(
ctx context.Context,
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) (err error) {
destinations, err := r.db.GetAllJoinedHosts(ctx)
if err != nil {
return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err)
}
edu := &gomatrixserverlib.EDU{
Type: "org.matrix.dendrite.wakeup",
Origin: string(r.cfg.Matrix.ServerName),
}
if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil {
return fmt.Errorf("r.queues.SendEDU: %w", err)
}
return nil
}

View file

@ -19,6 +19,7 @@ const (
FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest" FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest"
FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest" FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest"
FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive"
FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU"
) )
// NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API.
@ -105,3 +106,16 @@ func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup(
apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
} }
// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to.
func (h *httpFederationSenderInternalAPI) PerformBroadcastEDU(
ctx context.Context,
request *api.PerformBroadcastEDURequest,
response *api.PerformBroadcastEDUResponse,
) error {
span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU")
defer span.Finish()
apiURL := h.federationSenderURL + FederationSenderPerformBroadcastEDUPath
return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
}

View file

@ -76,4 +76,17 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route
return util.JSONResponse{Code: http.StatusOK, JSON: &response} return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}), }),
) )
internalAPIMux.Handle(FederationSenderPerformBroadcastEDUPath,
httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse {
var request api.PerformBroadcastEDURequest
var response api.PerformBroadcastEDUResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
return util.MessageResponse(http.StatusBadRequest, err.Error())
}
if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil {
return util.ErrorResponse(err)
}
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
}),
)
} }

View file

@ -256,7 +256,10 @@ func (oq *destinationQueue) backgroundSend() {
// PDUs waiting to be sent. By sending a message into the wake chan, // PDUs waiting to be sent. By sending a message into the wake chan,
// the next loop iteration will try processing these PDUs again, // the next loop iteration will try processing these PDUs again,
// subject to the backoff. // subject to the backoff.
oq.notifyPDUs <- true select {
case oq.notifyPDUs <- true:
default:
}
} }
} else if transaction { } else if transaction {
// If we successfully sent the transaction then clear out // If we successfully sent the transaction then clear out

View file

@ -26,6 +26,7 @@ type Database interface {
internal.PartitionStorer internal.PartitionStorer
UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error)
GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error)
GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error)
StoreJSON(ctx context.Context, js string) (int64, error) StoreJSON(ctx context.Context, js string) (int64, error)
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error
GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error)

View file

@ -57,10 +57,14 @@ const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" + "SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1" " WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
type joinedHostsStatements struct { type joinedHostsStatements struct {
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
} }
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
@ -77,6 +81,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return return
} }
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
return return
} }
@ -112,6 +119,27 @@ func (s *joinedHostsStatements) selectJoinedHosts(
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
} }
func (s *joinedHostsStatements) selectAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt( func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string, ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {

View file

@ -134,6 +134,13 @@ func (d *Database) GetJoinedHosts(
return d.selectJoinedHosts(ctx, roomID) return d.selectJoinedHosts(ctx, roomID)
} }
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
return d.selectAllJoinedHosts(ctx)
}
// StoreJSON adds a JSON blob into the queue JSON table and returns // StoreJSON adds a JSON blob into the queue JSON table and returns
// a NID. The NID will then be used when inserting the per-destination // a NID. The NID will then be used when inserting the per-destination
// metadata entries. // metadata entries.

View file

@ -56,10 +56,14 @@ const selectJoinedHostsSQL = "" +
"SELECT event_id, server_name FROM federationsender_joined_hosts" + "SELECT event_id, server_name FROM federationsender_joined_hosts" +
" WHERE room_id = $1" " WHERE room_id = $1"
const selectAllJoinedHostsSQL = "" +
"SELECT DISTINCT server_name FROM federationsender_joined_hosts"
type joinedHostsStatements struct { type joinedHostsStatements struct {
insertJoinedHostsStmt *sql.Stmt insertJoinedHostsStmt *sql.Stmt
deleteJoinedHostsStmt *sql.Stmt deleteJoinedHostsStmt *sql.Stmt
selectJoinedHostsStmt *sql.Stmt selectJoinedHostsStmt *sql.Stmt
selectAllJoinedHostsStmt *sql.Stmt
} }
func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
@ -76,6 +80,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) {
if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil {
return return
} }
if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil {
return
}
return return
} }
@ -115,6 +122,27 @@ func (s *joinedHostsStatements) selectJoinedHosts(
return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID)
} }
func (s *joinedHostsStatements) selectAllJoinedHosts(
ctx context.Context,
) ([]gomatrixserverlib.ServerName, error) {
rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed")
var result []gomatrixserverlib.ServerName
for rows.Next() {
var serverName string
if err = rows.Scan(&serverName); err != nil {
return nil, err
}
result = append(result, gomatrixserverlib.ServerName(serverName))
}
return result, rows.Err()
}
func joinedHostsFromStmt( func joinedHostsFromStmt(
ctx context.Context, stmt *sql.Stmt, roomID string, ctx context.Context, stmt *sql.Stmt, roomID string,
) ([]types.JoinedHost, error) { ) ([]types.JoinedHost, error) {

View file

@ -145,6 +145,13 @@ func (d *Database) GetJoinedHosts(
return d.selectJoinedHosts(ctx, roomID) return d.selectJoinedHosts(ctx, roomID)
} }
// GetAllJoinedHosts returns the currently joined hosts for
// all rooms known to the federation sender.
// Returns an error if something goes wrong.
func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) {
return d.selectAllJoinedHosts(ctx)
}
// StoreJSON adds a JSON blob into the queue JSON table and returns // StoreJSON adds a JSON blob into the queue JSON table and returns
// a NID. The NID will then be used when inserting the per-destination // a NID. The NID will then be used when inserting the per-destination
// metadata entries. // metadata entries.