diff --git a/federationsender/api/api.go b/federationsender/api/api.go index 82cdf9d8..8cec0bbd 100644 --- a/federationsender/api/api.go +++ b/federationsender/api/api.go @@ -48,6 +48,12 @@ type FederationSenderInternalAPI interface { request *PerformDirectoryLookupRequest, response *PerformDirectoryLookupResponse, ) error + // QueryServerJoinedToRoom checks if a single server is in a room right now. + QueryServerJoinedToRoom( + ctx context.Context, + request *QueryServerJoinedToRoomRequest, + response *QueryServerJoinedToRoomResponse, + ) error // Query the server names of the joined hosts in a room. // Unlike QueryJoinedHostsInRoom, this function returns a de-duplicated slice // containing only the server names (without information for membership events). @@ -183,6 +189,17 @@ type QueryJoinedHostServerNamesInRoomResponse struct { ServerNames []gomatrixserverlib.ServerName `json:"server_names"` } +// QueryJoinedHostServerNamesInRoomRequest is a request to QueryJoinedHostServerNames +type QueryServerJoinedToRoomRequest struct { + ServerName gomatrixserverlib.ServerName `json:"server_name"` + RoomID string `json:"room_id"` +} + +// QueryJoinedHostServerNamesInRoomResponse is a response to QueryJoinedHostServerNames +type QueryServerJoinedToRoomResponse struct { + Joined bool `json:"joined"` +} + type PerformBroadcastEDURequest struct { } diff --git a/federationsender/internal/query.go b/federationsender/internal/query.go index af531f7d..16cb6015 100644 --- a/federationsender/internal/query.go +++ b/federationsender/internal/query.go @@ -25,6 +25,16 @@ func (f *FederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( return } +// QueryJoinedHostServerNamesInRoom implements api.FederationSenderInternalAPI +func (f *FederationSenderInternalAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *api.QueryServerJoinedToRoomRequest, + response *api.QueryServerJoinedToRoomResponse, +) (err error) { + response.Joined, err = f.db.GetServerJoinedToRoom(ctx, request.ServerName, request.RoomID) + return +} + func (a *FederationSenderInternalAPI) fetchServerKeysDirectly(ctx context.Context, serverName gomatrixserverlib.ServerName) (*gomatrixserverlib.ServerKeys, error) { ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() diff --git a/federationsender/inthttp/client.go b/federationsender/inthttp/client.go index f08e610a..a54acf6a 100644 --- a/federationsender/inthttp/client.go +++ b/federationsender/inthttp/client.go @@ -15,6 +15,7 @@ import ( // HTTP paths for the internal HTTP API const ( FederationSenderQueryJoinedHostServerNamesInRoomPath = "/federationsender/queryJoinedHostServerNamesInRoom" + FederationSenderQueryServerJoinedToRoomPath = "/federationsender/queryServerJoinedToRoom" FederationSenderQueryServerKeysPath = "/federationsender/queryServerKeys" FederationSenderPerformDirectoryLookupRequestPath = "/federationsender/performDirectoryLookup" @@ -115,6 +116,19 @@ func (h *httpFederationSenderInternalAPI) QueryJoinedHostServerNamesInRoom( return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) } +// QueryServerJoinedToRoom implements FederationSenderInternalAPI +func (h *httpFederationSenderInternalAPI) QueryServerJoinedToRoom( + ctx context.Context, + request *api.QueryServerJoinedToRoomRequest, + response *api.QueryServerJoinedToRoomResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryServerJoinedToRoom") + defer span.Finish() + + apiURL := h.federationSenderURL + FederationSenderQueryServerJoinedToRoomPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + // Handle an instruction to make_join & send_join with a remote server. func (h *httpFederationSenderInternalAPI) PerformJoin( ctx context.Context, diff --git a/federationsender/inthttp/server.go b/federationsender/inthttp/server.go index a7fbc4ed..cd288055 100644 --- a/federationsender/inthttp/server.go +++ b/federationsender/inthttp/server.go @@ -27,6 +27,20 @@ func AddRoutes(intAPI api.FederationSenderInternalAPI, internalAPIMux *mux.Route return util.JSONResponse{Code: http.StatusOK, JSON: &response} }), ) + internalAPIMux.Handle( + FederationSenderQueryServerJoinedToRoomPath, + httputil.MakeInternalAPI("QueryServerJoinedToRoom", func(req *http.Request) util.JSONResponse { + var request api.QueryServerJoinedToRoomRequest + var response api.QueryServerJoinedToRoomResponse + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.ErrorResponse(err) + } + if err := intAPI.QueryServerJoinedToRoom(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) internalAPIMux.Handle( FederationSenderPerformJoinRequestPath, httputil.MakeInternalAPI("PerformJoinRequest", func(req *http.Request) util.JSONResponse { diff --git a/federationsender/storage/interface.go b/federationsender/storage/interface.go index 58c8a7cf..923dbf4e 100644 --- a/federationsender/storage/interface.go +++ b/federationsender/storage/interface.go @@ -32,6 +32,7 @@ type Database interface { GetAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) // GetJoinedHostsForRooms returns the complete set of servers in the rooms given. GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + GetServerJoinedToRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) PurgeRoomState(ctx context.Context, roomID string) error StoreJSON(ctx context.Context, js string) (*shared.Receipt, error) diff --git a/federationsender/storage/postgres/joined_hosts_table.go b/federationsender/storage/postgres/joined_hosts_table.go index 0c1e91ee..1c62d509 100644 --- a/federationsender/storage/postgres/joined_hosts_table.go +++ b/federationsender/storage/postgres/joined_hosts_table.go @@ -66,6 +66,9 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id = ANY($1)" +const selectServerJoinedToRoomSQL = "" + + "SELECT COUNT(*) FROM federation_sender_joined_hosts WHERE server_name = $1 AND room_id = $2" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +77,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt selectJoinedHostsForRoomsStmt *sql.Stmt + selectServerJoinedToRoomStmt *sql.Stmt } func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -102,6 +106,9 @@ func NewPostgresJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err erro if s.selectJoinedHostsForRoomsStmt, err = s.db.Prepare(selectJoinedHostsForRoomsSQL); err != nil { return } + if s.selectServerJoinedToRoomStmt, err = s.db.Prepare(selectServerJoinedToRoomSQL); err != nil { + return + } return } @@ -145,6 +152,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) SelectServerJoinedToRoom( + ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string, +) (bool, error) { + row := s.selectServerJoinedToRoomStmt.QueryRowContext(ctx, serverName, roomID) + if err := row.Err(); err != nil { + return false, err + } + var count int + if err := row.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationsender/storage/shared/storage.go b/federationsender/storage/shared/storage.go index 45c9febd..b60ede92 100644 --- a/federationsender/storage/shared/storage.go +++ b/federationsender/storage/shared/storage.go @@ -105,6 +105,15 @@ func (d *Database) GetJoinedHostsForRooms(ctx context.Context, roomIDs []string) return d.FederationSenderJoinedHosts.SelectJoinedHostsForRooms(ctx, roomIDs) } +// GetJoinedHosts returns the currently joined hosts for room, +// as known to federationserver. +// Returns an error if something goes wrong. +func (d *Database) GetServerJoinedToRoom( + ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string, +) (bool, error) { + return d.FederationSenderJoinedHosts.SelectServerJoinedToRoom(ctx, serverName, roomID) +} + // StoreJSON adds a JSON blob into the queue JSON table and returns // a NID. The NID will then be used when inserting the per-destination // metadata entries. diff --git a/federationsender/storage/sqlite3/joined_hosts_table.go b/federationsender/storage/sqlite3/joined_hosts_table.go index 4c0c1f51..71ea485f 100644 --- a/federationsender/storage/sqlite3/joined_hosts_table.go +++ b/federationsender/storage/sqlite3/joined_hosts_table.go @@ -66,6 +66,9 @@ const selectAllJoinedHostsSQL = "" + const selectJoinedHostsForRoomsSQL = "" + "SELECT DISTINCT server_name FROM federationsender_joined_hosts WHERE room_id IN ($1)" +const selectServerJoinedToRoomSQL = "" + + "SELECT COUNT(*) FROM federation_sender_joined_hosts WHERE server_name = $1 AND room_id = $2" + type joinedHostsStatements struct { db *sql.DB insertJoinedHostsStmt *sql.Stmt @@ -74,6 +77,7 @@ type joinedHostsStatements struct { selectJoinedHostsStmt *sql.Stmt selectAllJoinedHostsStmt *sql.Stmt // selectJoinedHostsForRoomsStmt *sql.Stmt - prepared at runtime due to variadic + selectServerJoinedToRoomStmt *sql.Stmt } func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) { @@ -99,6 +103,9 @@ func NewSQLiteJoinedHostsTable(db *sql.DB) (s *joinedHostsStatements, err error) if s.selectAllJoinedHostsStmt, err = db.Prepare(selectAllJoinedHostsSQL); err != nil { return } + if s.selectServerJoinedToRoomStmt, err = s.db.Prepare(selectServerJoinedToRoomSQL); err != nil { + return + } return } @@ -146,6 +153,20 @@ func (s *joinedHostsStatements) SelectJoinedHosts( return joinedHostsFromStmt(ctx, s.selectJoinedHostsStmt, roomID) } +func (s *joinedHostsStatements) SelectServerJoinedToRoom( + ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string, +) (bool, error) { + row := s.selectServerJoinedToRoomStmt.QueryRowContext(ctx, serverName, roomID) + if err := row.Err(); err != nil { + return false, err + } + var count int + if err := row.Scan(&count); err != nil { + return false, err + } + return count > 0, nil +} + func (s *joinedHostsStatements) SelectAllJoinedHosts( ctx context.Context, ) ([]gomatrixserverlib.ServerName, error) { diff --git a/federationsender/storage/tables/interface.go b/federationsender/storage/tables/interface.go index 663a4cb2..80ad879b 100644 --- a/federationsender/storage/tables/interface.go +++ b/federationsender/storage/tables/interface.go @@ -56,6 +56,7 @@ type FederationSenderJoinedHosts interface { SelectJoinedHosts(ctx context.Context, roomID string) ([]types.JoinedHost, error) SelectAllJoinedHosts(ctx context.Context) ([]gomatrixserverlib.ServerName, error) SelectJoinedHostsForRooms(ctx context.Context, roomIDs []string) ([]gomatrixserverlib.ServerName, error) + SelectServerJoinedToRoom(ctx context.Context, serverName gomatrixserverlib.ServerName, roomID string) (bool, error) } type FederationSenderBlacklist interface {