From b0a3ee6c5c063962384bb91c59ec753ddc8cfe5f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 16 Jul 2020 13:42:22 +0100 Subject: [PATCH] Add support for broadcasting wake-up EDUs to known hosts --- federationsender/api/api.go | 12 +++++++ federationsender/internal/perform.go | 22 ++++++++++++ federationsender/inthttp/client.go | 14 ++++++++ federationsender/inthttp/server.go | 13 +++++++ federationsender/queue/destinationqueue.go | 5 ++- federationsender/storage/interface.go | 1 + .../storage/postgres/joined_hosts_table.go | 34 +++++++++++++++++-- federationsender/storage/postgres/storage.go | 7 ++++ .../storage/sqlite3/joined_hosts_table.go | 34 +++++++++++++++++-- federationsender/storage/sqlite3/storage.go | 7 ++++ 10 files changed, 142 insertions(+), 7 deletions(-) diff --git a/federationsender/api/api.go b/federationsender/api/api.go index d90ffd29..b87af0eb 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -42,6 +42,12 @@ type FederationSenderInternalAPI interface { request *PerformServersAliveRequest, response *PerformServersAliveResponse, ) error + // Broadcasts an EDU to all servers in rooms we are joined to. + PerformBroadcastEDU( + ctx context.Context, + request *PerformBroadcastEDURequest, + response *PerformBroadcastEDUResponse, + ) error } type PerformDirectoryLookupRequest struct { @@ -91,3 +97,9 @@ type QueryJoinedHostServerNamesInRoomRequest struct { type QueryJoinedHostServerNamesInRoomResponse struct { ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } + +type PerformBroadcastEDURequest struct { +} + +type PerformBroadcastEDUResponse struct { +} diff --git a/federationsender/internal/perform.go b/federationsender/internal/perform.go index 96b1149d..d9a4b963 100644 --- a/federationsender/internal/perform.go +++ b/federationsender/internal/perform.go @@ -308,3 +308,25 @@ func (r *FederationSenderInternalAPI) PerformServersAlive( return nil } + +// PerformServersAlive implements api.FederationSenderInternalAPI +func (r *FederationSenderInternalAPI) PerformBroadcastEDU( + ctx context.Context, + request *api.PerformBroadcastEDURequest, + response *api.PerformBroadcastEDUResponse, +) (err error) { + destinations, err := r.db.GetAllJoinedHosts(ctx) + if err != nil { + return fmt.Errorf("r.db.GetAllJoinedHosts: %w", err) + } + + edu := &gomatrixserverlib.EDU{ + Type: "org.matrix.dendrite.wakeup", + Origin: string(r.cfg.Matrix.ServerName), + } + if err = r.queues.SendEDU(edu, r.cfg.Matrix.ServerName, destinations); err != nil { + return fmt.Errorf("r.queues.SendEDU: %w", err) + } + + return nil +} diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index 25de99cc..4d968919 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -19,6 +19,7 @@ const ( FederationSenderPerformJoinRequestPath = "/federationsender/performJoinRequest" FederationSenderPerformLeaveRequestPath = "/federationsender/performLeaveRequest" FederationSenderPerformServersAlivePath = "/federationsender/performServersAlive" + FederationSenderPerformBroadcastEDUPath = "/federationsender/performBroadcastEDU" ) // NewFederationSenderClient creates a FederationSenderInternalAPI implemented by talking to a HTTP POST API. @@ -105,3 +106,16 @@ func (h *httpFederationSenderInternalAPI) PerformDirectoryLookup( apiURL := h.federationSenderURL + FederationSenderPerformDirectoryLookupRequestPath return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } + +// Handle an instruction to broadcast an EDU to all servers in rooms we are joined to. +func (h *httpFederationSenderInternalAPI) PerformBroadcastEDU( + ctx context.Context, + request *api.PerformBroadcastEDURequest, + response *api.PerformBroadcastEDUResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformBroadcastEDU") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderPerformBroadcastEDUPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index a4f3d63d..ee05cf95 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -76,4 +76,17 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle(FederationSenderPerformBroadcastEDUPath, + httputil.MakeInternalAPI("PerformBroadcastEDU", func(req *http.Request) util.JSONResponse { + var request api.PerformBroadcastEDURequest + var response api.PerformBroadcastEDUResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := intAPI.PerformBroadcastEDU(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) } diff --git a/federationsender/queue/destinationqueue.go b/federationsender/queue/destinationqueue.go index 82cb343f..33741f80 100644 --- a/federationsender/queue/destinationqueue.go +++ b/federationsender/queue/destinationqueue.go @@ -256,7 +256,10 @@ func (oq *destinationQueue) backgroundSend() { // PDUs waiting to be sent. By sending a message into the wake chan, // the next loop iteration will try processing these PDUs again, // subject to the backoff. - oq.notifyPDUs <- true + select { + case oq.notifyPDUs <- true: + default: + } } } else if transaction { // If we successfully sent the transaction then clear out diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 4bf36c24..6fff3518 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -26,6 +26,7 @@ type Database interface { internal.PartitionStorer UpdateRoom(ctx context.Context, roomID, oldEventID, newEventID string, addHosts []types.JoinedHost, removeHosts []string) (joinedHosts []types.JoinedHost, err error) GetJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) + GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) StoreJSON(ctx context.Context, js string) (int64, error) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, nids []int64) error GetNextTransactionPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (gomatrixserverlib.TransactionID, []*gomatrixserverlib.HeaderedEvent, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index c0f9a7d5..2612e7e0 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -57,10 +57,14 @@ const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" +const selectAllJoinedHostsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + type joinedHostsStatements struct { - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt } func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { @@ -77,6 +81,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { return } + if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + return + } return } @@ -112,6 +119,27 @@ func (s *joinedHostsStatements) selectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) selectAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/postgres/storage.go b/federationsender/storage/postgres/storage.go index 80686e09..1535ebdf 100644 --- a/federationsender/storage/postgres/storage.go +++ b/federationsender/storage/postgres/storage.go @@ -134,6 +134,13 @@ func (d *Database) GetJoinedHosts( return d.selectJoinedHosts(ctx, roomID) } +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.selectAllJoinedHosts(ctx) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index d9824658..fd9ffedc 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -56,10 +56,14 @@ const selectJoinedHostsSQL = "" + "SELECT event_id, server_name FROM federationsender_joined_hosts" + " WHERE room_id = $1" +const selectAllJoinedHostsSQL = "" + + "SELECT DISTINCT server_name FROM federationsender_joined_hosts" + type joinedHostsStatements struct { - insertJoinedHostsStmt *sql.Stmt - deleteJoinedHostsStmt *sql.Stmt - selectJoinedHostsStmt *sql.Stmt + insertJoinedHostsStmt *sql.Stmt + deleteJoinedHostsStmt *sql.Stmt + selectJoinedHostsStmt *sql.Stmt + selectAllJoinedHostsStmt *sql.Stmt } func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { @@ -76,6 +80,9 @@ func (s *joinedHostsStatements) prepare(db *sql.DB) (err error) { if s.selectJoinedHostsStmt, err = db.Prepare(selectJoinedHostsSQL); err != nil { return } + if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { + return + } return } @@ -115,6 +122,27 @@ func (s *joinedHostsStatements) selectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) selectAllJoinedHosts( + ctx context.Context, +) ([]gomatrixserverlib.ServerName, error) { + rows, err := s.selectAllJoinedHostsStmt.QueryContext(ctx) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "selectAllJoinedHosts: rows.close() failed") + + var result []gomatrixserverlib.ServerName + for rows.Next() { + var serverName string + if err = rows.Scan(&serverName); err != nil { + return nil, err + } + result = append(result, gomatrixserverlib.ServerName(serverName)) + } + + return result, rows.Err() +} + func joinedHostsFromStmt( ctx context.Context, stmt *sql.Stmt, roomID string, ) ([]types.JoinedHost, error) { diff --git a/federationsender/storage/sqlite3/storage.go b/federationsender/storage/sqlite3/storage.go index 7fe6b65b..fee2a5b3 100644 --- a/federationsender/storage/sqlite3/storage.go +++ b/federationsender/storage/sqlite3/storage.go @@ -145,6 +145,13 @@ func (d *Database) GetJoinedHosts( return d.selectJoinedHosts(ctx, roomID) } +// GetAllJoinedHosts returns the currently joined hosts for +// all rooms known to the federation sender. +// Returns an error if something goes wrong. +func (d *Database) GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) { + return d.selectAllJoinedHosts(ctx) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries.