mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
WIP Third Party Protocol Definition Lookup
This commit is contained in:
parent
0eb6078dad
commit
9c0919bbd5
12 changed files with 681 additions and 85 deletions
|
@ -21,6 +21,7 @@ import (
|
|||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice/types"
|
||||
commonHTTP "github.com/matrix-org/dendrite/common/http"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
@ -46,6 +47,29 @@ type RoomAliasExistsResponse struct {
|
|||
AliasExists bool `json:"exists"`
|
||||
}
|
||||
|
||||
// GetProtocolDefinitionRequest is a request to the appservice component asking
|
||||
// for the definition of a single third party protocol
|
||||
type GetProtocolDefinitionRequest struct {
|
||||
ProtocolID string `json:"protocol_definition"`
|
||||
}
|
||||
|
||||
// GetProtocolDefinitionResponse is a response providing a protocol definition
|
||||
// for the given protocol ID
|
||||
type GetProtocolDefinitionResponse struct {
|
||||
ProtocolDefinition string `json:"protocol_definition"`
|
||||
}
|
||||
|
||||
// GetAllProtocolDefinitionsRequest is a request to the appservice component
|
||||
// asking for what third party protocols are known and their definitions
|
||||
type GetAllProtocolDefinitionsRequest struct {
|
||||
}
|
||||
|
||||
// GetAllProtocolDefinitionsResponse is a response containing all known third
|
||||
// party IDs and their definitions
|
||||
type GetAllProtocolDefinitionsResponse struct {
|
||||
Protocols types.ThirdPartyProtocols `json:"protocols"`
|
||||
}
|
||||
|
||||
// AppServiceQueryAPI is used to query user and room alias data from application
|
||||
// services
|
||||
type AppServiceQueryAPI interface {
|
||||
|
@ -56,10 +80,26 @@ type AppServiceQueryAPI interface {
|
|||
response *RoomAliasExistsResponse,
|
||||
) error
|
||||
// TODO: QueryUserIDExists
|
||||
GetProtocolDefinition(
|
||||
ctx context.Context,
|
||||
req *GetProtocolDefinitionRequest,
|
||||
response *GetProtocolDefinitionResponse,
|
||||
) error
|
||||
GetAllProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
req *GetAllProtocolDefinitionsRequest,
|
||||
response *GetAllProtocolDefinitionsResponse,
|
||||
) error
|
||||
}
|
||||
|
||||
// AppServiceRoomAliasExistsPath is the HTTP path for the RoomAliasExists API
|
||||
const AppServiceRoomAliasExistsPath = "/api/appservice/RoomAliasExists"
|
||||
// RoomAliasExistsPath is the HTTP path for the RoomAliasExists API
|
||||
const RoomAliasExistsPath = "/api/appservice/RoomAliasExists"
|
||||
|
||||
// GetProtocolDefinitionPath is the HTTP path for the GetProtocolDefinition API
|
||||
const GetProtocolDefinitionPath = "/api/appservice/GetProtocolDefinition"
|
||||
|
||||
// GetAllProtocolDefinitionsPath is the HTTP path for the GetAllProtocolDefinitions API
|
||||
const GetAllProtocolDefinitionsPath = "/api/appservice/GetAllProtocolDefinitions"
|
||||
|
||||
// httpAppServiceQueryAPI contains the URL to an appservice query API and a
|
||||
// reference to a httpClient used to reach it
|
||||
|
@ -90,6 +130,32 @@ func (h *httpAppServiceQueryAPI) RoomAliasExists(
|
|||
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceRoomAliasExists")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.appserviceURL + AppServiceRoomAliasExistsPath
|
||||
apiURL := h.appserviceURL + RoomAliasExistsPath
|
||||
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// GetProtocolDefinition implements AppServiceQueryAPI
|
||||
func (h *httpAppServiceQueryAPI) GetProtocolDefinition(
|
||||
ctx context.Context,
|
||||
request *GetProtocolDefinitionRequest,
|
||||
response *GetProtocolDefinitionResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceGetProtocolDefinition")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.appserviceURL + GetProtocolDefinitionPath
|
||||
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// GetAllProtocolDefinitions implements AppServiceQueryAPI
|
||||
func (h *httpAppServiceQueryAPI) GetAllProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
request *GetAllProtocolDefinitionsRequest,
|
||||
response *GetAllProtocolDefinitionsResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "appserviceGetAllProtocolDefinitions")
|
||||
defer span.Finish()
|
||||
|
||||
apiURL := h.appserviceURL + GetAllProtocolDefinitionsPath
|
||||
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
package appservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -31,7 +32,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/common/transactions"
|
||||
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetupAppServiceAPIComponent sets up and registers HTTP handlers for the AppServices
|
||||
|
@ -47,7 +48,7 @@ func SetupAppServiceAPIComponent(
|
|||
// Create a connection to the appservice postgres DB
|
||||
appserviceDB, err := storage.NewDatabase(string(base.Cfg.Database.AppService))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Panicf("failed to connect to appservice db")
|
||||
log.WithError(err).Panicf("failed to connect to appservice db")
|
||||
}
|
||||
|
||||
// Wrap application services in a type that relates the application service and
|
||||
|
@ -72,6 +73,7 @@ func SetupAppServiceAPIComponent(
|
|||
appserviceQueryAPI := query.AppServiceQueryAPI{
|
||||
HTTPClient: httpClient,
|
||||
Cfg: base.Cfg,
|
||||
Db: appserviceDB,
|
||||
}
|
||||
|
||||
appserviceQueryAPI.SetupHTTP(http.DefaultServeMux)
|
||||
|
@ -81,13 +83,11 @@ func SetupAppServiceAPIComponent(
|
|||
roomserverQueryAPI, roomserverAliasAPI, workerStates,
|
||||
)
|
||||
if err := consumer.Start(); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start app service roomserver consumer")
|
||||
log.WithError(err).Panicf("failed to start app service roomserver consumer")
|
||||
}
|
||||
|
||||
// Create application service transaction workers
|
||||
if err := workers.SetupTransactionWorkers(appserviceDB, workerStates); err != nil {
|
||||
logrus.WithError(err).Panicf("failed to start app service transaction workers")
|
||||
}
|
||||
// Create application service transaction and third party workers
|
||||
setupWorkers(appserviceDB, workerStates)
|
||||
|
||||
// Set up HTTP Endpoints
|
||||
routing.Setup(
|
||||
|
@ -97,3 +97,29 @@ func SetupAppServiceAPIComponent(
|
|||
|
||||
return &appserviceQueryAPI
|
||||
}
|
||||
|
||||
// setupWorkers creates worker goroutines that each interface with a connected
|
||||
// application service.
|
||||
func setupWorkers(
|
||||
appserviceDB *storage.Database,
|
||||
workerStates []types.ApplicationServiceWorkerState,
|
||||
) {
|
||||
// Clear all old protocol definitions on startup
|
||||
appserviceDB.ClearProtocolDefinitions(context.TODO())
|
||||
|
||||
// Create a worker that handles transmitting events to a single homeserver
|
||||
for _, workerState := range workerStates {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": workerState.AppService.ID,
|
||||
}).Info("starting application service")
|
||||
|
||||
// Don't create a worker if this AS doesn't want to receive events
|
||||
if workerState.AppService.URL != "" {
|
||||
// Worker to handle sending event transactions
|
||||
go workers.TransactionWorker(appserviceDB, workerState)
|
||||
|
||||
// Worker to handle retreiving information about third parties
|
||||
go workers.ThirdPartyWorker(appserviceDB, workerState.AppService)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/appservice/storage"
|
||||
"github.com/matrix-org/dendrite/common"
|
||||
"github.com/matrix-org/dendrite/common/config"
|
||||
"github.com/matrix-org/util"
|
||||
|
@ -38,6 +39,45 @@ const roomAliasExistsPath = "/rooms/"
|
|||
type AppServiceQueryAPI struct {
|
||||
HTTPClient *http.Client
|
||||
Cfg *config.Dendrite
|
||||
Db *storage.Database
|
||||
}
|
||||
|
||||
// GetProtocolDefinition queries the database for the protocol definition of a
|
||||
// protocol with given ID
|
||||
func (a *AppServiceQueryAPI) GetProtocolDefinition(
|
||||
ctx context.Context,
|
||||
request *api.GetProtocolDefinitionRequest,
|
||||
response *api.GetProtocolDefinitionResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceGetProtocolDefinition")
|
||||
defer span.Finish()
|
||||
|
||||
protocolDefinition, err := a.Db.GetProtocolDefinition(ctx, request.ProtocolID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.ProtocolDefinition = protocolDefinition
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAllProtocolDefinitions queries the database for all known protocol
|
||||
// definitions and their IDs
|
||||
func (a *AppServiceQueryAPI) GetAllProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
request *api.GetAllProtocolDefinitionsRequest,
|
||||
response *api.GetAllProtocolDefinitionsResponse,
|
||||
) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "ApplicationServiceGetAllProtocolDefinitions")
|
||||
defer span.Finish()
|
||||
|
||||
protocolDefinitions, err := a.Db.GetAllProtocolDefinitions(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
response.Protocols = protocolDefinitions
|
||||
return nil
|
||||
}
|
||||
|
||||
// RoomAliasExists performs a request to '/room/{roomAlias}' on all known
|
||||
|
@ -120,7 +160,7 @@ func makeHTTPClient() *http.Client {
|
|||
// handles and muxes incoming api requests the to internal AppServiceQueryAPI.
|
||||
func (a *AppServiceQueryAPI) SetupHTTP(servMux *http.ServeMux) {
|
||||
servMux.Handle(
|
||||
api.AppServiceRoomAliasExistsPath,
|
||||
api.RoomAliasExistsPath,
|
||||
common.MakeInternalAPI("appserviceRoomAliasExists", func(req *http.Request) util.JSONResponse {
|
||||
var request api.RoomAliasExistsRequest
|
||||
var response api.RoomAliasExistsResponse
|
||||
|
@ -133,4 +173,32 @@ func (a *AppServiceQueryAPI) SetupHTTP(servMux *http.ServeMux) {
|
|||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
servMux.Handle(
|
||||
api.GetProtocolDefinitionPath,
|
||||
common.MakeInternalAPI("appserviceGetProtocolDefinition", func(req *http.Request) util.JSONResponse {
|
||||
var request api.GetProtocolDefinitionRequest
|
||||
var response api.GetProtocolDefinitionResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
if err := a.GetProtocolDefinition(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
servMux.Handle(
|
||||
api.GetAllProtocolDefinitionsPath,
|
||||
common.MakeInternalAPI("appserviceGetAllProtocolDefinitions", func(req *http.Request) util.JSONResponse {
|
||||
var request api.GetAllProtocolDefinitionsRequest
|
||||
var response api.GetAllProtocolDefinitionsResponse
|
||||
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
if err := a.GetAllProtocolDefinitions(req.Context(), &request, &response); err != nil {
|
||||
return util.ErrorResponse(err)
|
||||
}
|
||||
return util.JSONResponse{Code: http.StatusOK, JSON: &response}
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -20,14 +20,16 @@ import (
|
|||
|
||||
// Import postgres database driver
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/matrix-org/dendrite/appservice/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// Database stores events intended to be later sent to application services
|
||||
type Database struct {
|
||||
events eventsStatements
|
||||
txnID txnStatements
|
||||
db *sql.DB
|
||||
events eventsStatements
|
||||
txnID txnStatements
|
||||
thirdparty thirdPartyStatements
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewDatabase opens a new database
|
||||
|
@ -48,6 +50,10 @@ func (d *Database) prepare() error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := d.thirdparty.prepare(d.db); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.txnID.prepare(d.db)
|
||||
}
|
||||
|
||||
|
@ -108,3 +114,36 @@ func (d *Database) GetLatestTxnID(
|
|||
) (int, error) {
|
||||
return d.txnID.selectTxnID(ctx)
|
||||
}
|
||||
|
||||
// GetProtocolDefinition retreives a JSON-encoded protocol definition given a
|
||||
// protocol ID
|
||||
func (d *Database) GetProtocolDefinition(
|
||||
ctx context.Context,
|
||||
protocolID string,
|
||||
) (string, error) {
|
||||
return d.thirdparty.selectProtocolDefinition(ctx, protocolID)
|
||||
}
|
||||
|
||||
// GetAllProtocolDefinitions retrieves a map of all known third party protocols
|
||||
func (d *Database) GetAllProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
) (types.ThirdPartyProtocols, error) {
|
||||
return d.thirdparty.selectAllProtocolDefinitions(ctx)
|
||||
}
|
||||
|
||||
// StoreProtocolDefinition stores a protocol and its definition
|
||||
func (d *Database) StoreProtocolDefinition(
|
||||
ctx context.Context,
|
||||
protocolID, protocolDefinition string,
|
||||
) error {
|
||||
return d.thirdparty.insertProtocolDefinition(ctx, protocolID, protocolDefinition)
|
||||
}
|
||||
|
||||
// ClearProtocolDefinition clears all protocol definition entries in the
|
||||
// database. This is done on each startup to wipe old protocol definitions from
|
||||
// previous application services.
|
||||
func (d *Database) ClearProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
) error {
|
||||
return d.thirdparty.clearProtocolDefinitions(ctx)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright 2018 New Vector Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice/types"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const thirdPartySchema = `
|
||||
-- Stores protocol definitions for clients to later request
|
||||
CREATE TABLE IF NOT EXISTS appservice_third_party_protocol_def (
|
||||
-- The ID of the procotol
|
||||
protocol_id TEXT NOT NULL PRIMARY KEY,
|
||||
-- The JSON-encoded protocol definition
|
||||
protocol_definition TEXT NOT NULL,
|
||||
UNIQUE(protocol_id)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS appservice_third_party_protocol_def_id
|
||||
ON appservice_third_party_protocol_def(protocol_id);
|
||||
`
|
||||
|
||||
const selectProtocolDefinitionSQL = "" +
|
||||
"SELECT protocol_definition FROM appservice_third_party_protocol_def " +
|
||||
"WHERE protocol_id = $1"
|
||||
|
||||
const selectAllProtocolDefinitionsSQL = "" +
|
||||
"SELECT protocol_id, protocol_definition FROM appservice_third_party_protocol_def"
|
||||
|
||||
const insertProtocolDefinitionSQL = "" +
|
||||
"INSERT INTO appservice_third_party_protocol_def(protocol_id, protocol_definition) " +
|
||||
"VALUES ($1, $2)"
|
||||
|
||||
const clearProtocolDefinitionsSQL = "" +
|
||||
"TRUNCATE appservice_third_party_protocol_def"
|
||||
|
||||
type thirdPartyStatements struct {
|
||||
selectProtocolDefinitionStmt *sql.Stmt
|
||||
selectAllProtocolDefinitionsStmt *sql.Stmt
|
||||
insertProtocolDefinitionStmt *sql.Stmt
|
||||
clearProtocolDefinitionsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func (s *thirdPartyStatements) prepare(db *sql.DB) (err error) {
|
||||
_, err = db.Exec(thirdPartySchema)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.selectProtocolDefinitionStmt, err = db.Prepare(selectProtocolDefinitionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.selectAllProtocolDefinitionsStmt, err = db.Prepare(selectAllProtocolDefinitionsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.insertProtocolDefinitionStmt, err = db.Prepare(insertProtocolDefinitionSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if s.clearProtocolDefinitionsStmt, err = db.Prepare(clearProtocolDefinitionsSQL); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// selectProtocolDefinition returns a single protocol definition for a given ID.
|
||||
// Returns an empty string if the ID was not found.
|
||||
func (s *thirdPartyStatements) selectProtocolDefinition(
|
||||
ctx context.Context,
|
||||
protocolID string,
|
||||
) (protocolDefinition string, err error) {
|
||||
err = s.selectProtocolDefinitionStmt.QueryRowContext(
|
||||
ctx, protocolID,
|
||||
).Scan(&protocolDefinition)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// selectAllProtocolDefinitions returns all protcol IDs and definitions in the
|
||||
// database. Returns an empty map if no definitions were found.
|
||||
func (s *thirdPartyStatements) selectAllProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
) (protocols types.ThirdPartyProtocols, err error) {
|
||||
protocolDefinitionRows, err := s.selectAllProtocolDefinitionsStmt.QueryContext(ctx)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
err = protocolDefinitionRows.Close()
|
||||
if err != nil {
|
||||
log.WithError(err).Fatalf("unable to close protocol definitions")
|
||||
}
|
||||
}()
|
||||
|
||||
for protocolDefinitionRows.Next() {
|
||||
var protocolID, protocolDefinition string
|
||||
if err = protocolDefinitionRows.Scan(&protocolID, &protocolDefinition); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
protocols[protocolID] = gomatrixserverlib.RawJSON(protocolDefinition)
|
||||
}
|
||||
|
||||
return protocols, nil
|
||||
}
|
||||
|
||||
// insertProtocolDefinition inserts a protocol ID along with its definition in
|
||||
// order for clients to later retreive it from the client-server API.
|
||||
func (s *thirdPartyStatements) insertProtocolDefinition(
|
||||
ctx context.Context,
|
||||
protocolID, protocolDefinition string,
|
||||
) (err error) {
|
||||
_, err = s.insertProtocolDefinitionStmt.ExecContext(
|
||||
ctx,
|
||||
protocolID,
|
||||
protocolDefinition,
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
// clearProtocolDefinitions removes all protocol definitions from the database.
|
||||
func (s *thirdPartyStatements) clearProtocolDefinitions(
|
||||
ctx context.Context,
|
||||
) (err error) {
|
||||
_, err = s.clearProtocolDefinitionsStmt.ExecContext(ctx)
|
||||
return
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright 2018 New Vector Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package types
|
||||
|
||||
import (
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// ThirdPartyProtocols is a map of all third party protocols supported by
|
||||
// connected application services.
|
||||
type ThirdPartyProtocols map[string]gomatrixserverlib.RawJSON
|
|
@ -0,0 +1,151 @@
|
|||
// Copyright 2018 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package workers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/matrix-org/dendrite/appservice/storage"
|
||||
"github.com/matrix-org/dendrite/common/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// Timeout for requests to an application service to complete
|
||||
requestTimeout = time.Second * 300
|
||||
)
|
||||
|
||||
const protocolPath = "/_matrix/app/unstable/thirdparty/protocol/"
|
||||
|
||||
// ThirdPartyWorker interfaces with a given application service on third party
|
||||
// network related information.
|
||||
// At the moment it simply asks for information on protocols that an application
|
||||
// service supports, then exits.
|
||||
func ThirdPartyWorker(
|
||||
db *storage.Database,
|
||||
appservice config.ApplicationService,
|
||||
) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Grab the HTTP client for sending requests to app services
|
||||
client := &http.Client{
|
||||
Timeout: requestTimeout,
|
||||
// TODO: Verify certificates
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: true, // nolint: gas
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
backoffCount := 0
|
||||
|
||||
// Retrieve protocol information from the application service
|
||||
for i := 0; i < len(appservice.Protocols); i++ {
|
||||
protocolID := appservice.Protocols[i]
|
||||
protocolDefinition, err := retreiveProtocolInformation(
|
||||
ctx, client, appservice, protocolID,
|
||||
)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": appservice.ID,
|
||||
"backoff_exponent": backoffCount,
|
||||
}).WithError(err).Warn("error contacting appservice thirdparty endpoints")
|
||||
|
||||
// Backoff before contacting again
|
||||
backoff(&backoffCount)
|
||||
|
||||
// Try this protocol again
|
||||
i--
|
||||
continue
|
||||
}
|
||||
|
||||
// Cache protocol definition for clients to request later
|
||||
storeProtocolDefinition(ctx, db, appservice, protocolID, protocolDefinition)
|
||||
}
|
||||
}
|
||||
|
||||
// backoff for the request amount of 2^number seconds
|
||||
// We want to support a few different use cases. Application services that don't
|
||||
// implement these endpoints and thus will always return an error. Application
|
||||
// services that are not currently up when Dendrite starts. Application services
|
||||
// that are broken for a while but will come back online later.
|
||||
// We can support all of these without being too resource intensive with
|
||||
// exponential backoff.
|
||||
func backoff(exponent *int) {
|
||||
// Calculate how long to backoff for
|
||||
backoffDuration := time.Duration(math.Pow(2, float64(*exponent)))
|
||||
backoffSeconds := time.Second * backoffDuration
|
||||
|
||||
if *exponent < 6 {
|
||||
*exponent++
|
||||
}
|
||||
|
||||
// Backoff
|
||||
time.Sleep(backoffSeconds)
|
||||
}
|
||||
|
||||
// retreiveProtocolInformation contacts an application service and asks for
|
||||
// information about a given protocol.
|
||||
func retreiveProtocolInformation(
|
||||
ctx context.Context,
|
||||
httpClient *http.Client,
|
||||
appservice config.ApplicationService,
|
||||
protocol string,
|
||||
) (string, error) {
|
||||
// Create a request to the application service
|
||||
requestURL := appservice.URL + protocolPath + protocol
|
||||
req, err := http.NewRequest(http.MethodGet, requestURL, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Perform the request
|
||||
resp, err := httpClient.Do(req.WithContext(ctx))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Check that the request was successful
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// TODO: Handle non-200 error codes from application services
|
||||
return "", fmt.Errorf("non-OK status code %d returned from AS", resp.StatusCode)
|
||||
}
|
||||
|
||||
// Read the response body
|
||||
body, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// storeProtocolDefinition stores a protocol definition along with the protocol
|
||||
// ID in the database
|
||||
func storeProtocolDefinition(
|
||||
ctx context.Context,
|
||||
db *storage.Database,
|
||||
appservice config.ApplicationService,
|
||||
protocolID, protocolDefinition string,
|
||||
) error {
|
||||
return db.StoreProtocolDefinition(ctx, protocolID, protocolDefinition)
|
||||
}
|
|
@ -20,7 +20,6 @@ import (
|
|||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -31,38 +30,19 @@ import (
|
|||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
const (
|
||||
// Maximum size of events sent in each transaction.
|
||||
transactionBatchSize = 50
|
||||
// Timeout for sending a single transaction to an application service.
|
||||
transactionTimeout = time.Second * 60
|
||||
)
|
||||
|
||||
// SetupTransactionWorkers spawns a separate goroutine for each application
|
||||
// service. Each of these "workers" handle taking all events intended for their
|
||||
// app service, batch them up into a single transaction (up to a max transaction
|
||||
// TransactionWorker is a goroutine that sends any queued events to the application service
|
||||
// it is given. Each worker handles taking all events intended for their app
|
||||
// service, batch them up into a single transaction (up to a max transaction
|
||||
// size), then send that off to the AS's /transactions/{txnID} endpoint. It also
|
||||
// handles exponentially backing off in case the AS isn't currently available.
|
||||
func SetupTransactionWorkers(
|
||||
appserviceDB *storage.Database,
|
||||
workerStates []types.ApplicationServiceWorkerState,
|
||||
) error {
|
||||
// Create a worker that handles transmitting events to a single homeserver
|
||||
for _, workerState := range workerStates {
|
||||
// Don't create a worker if this AS doesn't want to receive events
|
||||
if workerState.AppService.URL != "" {
|
||||
go worker(appserviceDB, workerState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// worker is a goroutine that sends any queued events to the application service
|
||||
// it is given.
|
||||
func worker(db *storage.Database, ws types.ApplicationServiceWorkerState) {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
}).Info("starting application service")
|
||||
func TransactionWorker(db *storage.Database, ws types.ApplicationServiceWorkerState) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Create a HTTP client for sending requests to app services
|
||||
|
@ -107,8 +87,13 @@ func worker(db *storage.Database, ws types.ApplicationServiceWorkerState) {
|
|||
// Backoff if the application service does not respond
|
||||
err = send(client, ws.AppService, txnID, transactionJSON)
|
||||
if err != nil {
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
"backoff_exponent": ws.Backoff,
|
||||
}).WithError(err).Warnf("unable to send transactions successfully, backing off")
|
||||
|
||||
// Backoff
|
||||
backoff(&ws, err)
|
||||
backoff(&ws.Backoff)
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -132,26 +117,6 @@ func worker(db *storage.Database, ws types.ApplicationServiceWorkerState) {
|
|||
}
|
||||
}
|
||||
|
||||
// backoff pauses the calling goroutine for a 2^some backoff exponent seconds
|
||||
func backoff(ws *types.ApplicationServiceWorkerState, err error) {
|
||||
// Calculate how long to backoff for
|
||||
backoffDuration := time.Duration(math.Pow(2, float64(ws.Backoff)))
|
||||
backoffSeconds := time.Second * backoffDuration
|
||||
|
||||
log.WithFields(log.Fields{
|
||||
"appservice": ws.AppService.ID,
|
||||
}).WithError(err).Warnf("unable to send transactions successfully, backing off for %ds",
|
||||
backoffDuration)
|
||||
|
||||
ws.Backoff++
|
||||
if ws.Backoff > 6 {
|
||||
ws.Backoff = 6
|
||||
}
|
||||
|
||||
// Backoff
|
||||
time.Sleep(backoffSeconds)
|
||||
}
|
||||
|
||||
// createTransaction takes in a slice of AS events, stores them in an AS
|
||||
// transaction, and JSON-encodes the results.
|
||||
func createTransaction(
|
||||
|
|
|
@ -376,6 +376,20 @@ func Setup(
|
|||
}),
|
||||
).Methods(http.MethodPut, http.MethodOptions)
|
||||
|
||||
// Third party lookups
|
||||
r0mux.Handle("/thirdparty/protocol/{protocolID}",
|
||||
common.MakeExternalAPI("get_protocols", func(req *http.Request) util.JSONResponse {
|
||||
vars := mux.Vars(req)
|
||||
return GetThirdPartyProtocol(req, asAPI, vars["protocolID"])
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
r0mux.Handle("/thirdparty/protocols",
|
||||
common.MakeExternalAPI("get_protocols", func(req *http.Request) util.JSONResponse {
|
||||
return GetThirdPartyProtocols(req, asAPI)
|
||||
}),
|
||||
).Methods(http.MethodGet, http.MethodOptions)
|
||||
|
||||
// Stub implementations for sytest
|
||||
r0mux.Handle("/events",
|
||||
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
// Copyright 2018 Vector Creations Ltd
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package routing
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
appserviceAPI "github.com/matrix-org/dendrite/appservice/api"
|
||||
"github.com/matrix-org/dendrite/clientapi/httputil"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
// GetThirdPartyProtocol returns the protocol definition of a single, given
|
||||
// protocol ID
|
||||
func GetThirdPartyProtocol(
|
||||
req *http.Request,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
protocolID string,
|
||||
) util.JSONResponse {
|
||||
// Retrieve a single protocol definition from the appservice component
|
||||
queryReq := appserviceAPI.GetProtocolDefinitionRequest{
|
||||
ProtocolID: protocolID,
|
||||
}
|
||||
var queryRes appserviceAPI.GetProtocolDefinitionResponse
|
||||
if err := asAPI.GetProtocolDefinition(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: queryRes.ProtocolDefinition,
|
||||
}
|
||||
}
|
||||
|
||||
// GetThirdPartyProtocols returns all known third party protocols provided by
|
||||
// application services connected to this homeserver
|
||||
func GetThirdPartyProtocols(
|
||||
req *http.Request,
|
||||
asAPI appserviceAPI.AppServiceQueryAPI,
|
||||
) util.JSONResponse {
|
||||
// Retrieve all known protocols from appservice component
|
||||
queryReq := appserviceAPI.GetAllProtocolDefinitionsRequest{}
|
||||
var queryRes appserviceAPI.GetAllProtocolDefinitionsResponse
|
||||
if err := asAPI.GetAllProtocolDefinitions(req.Context(), &queryReq, &queryRes); err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
// TODO: Check what we get if no protocols defined by anyone
|
||||
|
||||
// Marshal protocols to JSON
|
||||
protocolJSON, err := json.Marshal(queryRes.Protocols)
|
||||
if err != nil {
|
||||
return httputil.LogThenError(req, err)
|
||||
}
|
||||
|
||||
// Return protocol IDs along with definitions
|
||||
return util.JSONResponse{
|
||||
Code: http.StatusOK,
|
||||
JSON: protocolJSON,
|
||||
}
|
||||
}
|
|
@ -221,6 +221,7 @@ func appendExclusiveNamespaceRegexs(
|
|||
func checkErrors(config *Dendrite) (err error) {
|
||||
var idMap = make(map[string]bool)
|
||||
var tokenMap = make(map[string]bool)
|
||||
var protocolMap = make(map[string]bool)
|
||||
|
||||
// Compile regexp object for checking groupIDs
|
||||
groupIDRegexp := regexp.MustCompile(`\+.*:.*`)
|
||||
|
@ -236,38 +237,56 @@ func checkErrors(config *Dendrite) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Check if we've already seen this ID. No two application services
|
||||
// can have the same ID or token.
|
||||
if idMap[appservice.ID] {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Application service ID %s must be unique", appservice.ID,
|
||||
)})
|
||||
if err = duplicationCheck(appservice, &idMap, &tokenMap, &protocolMap); err != nil {
|
||||
return err
|
||||
}
|
||||
// Check if we've already seen this token
|
||||
if tokenMap[appservice.ASToken] {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Application service Token %s must be unique", appservice.ASToken,
|
||||
)})
|
||||
}
|
||||
|
||||
// Add the id/token to their respective maps if we haven't already
|
||||
// seen them.
|
||||
idMap[appservice.ID] = true
|
||||
tokenMap[appservice.ASToken] = true
|
||||
|
||||
// TODO: Remove once rate_limited is implemented
|
||||
if appservice.RateLimited {
|
||||
log.Warn("WARNING: Application service option rate_limited is currently unimplemented")
|
||||
}
|
||||
// TODO: Remove once protocols is implemented
|
||||
if len(appservice.Protocols) > 0 {
|
||||
log.Warn("WARNING: Application service option protocols is currently unimplemented")
|
||||
log.Warn("WARNING: Application service option 'rate_limited' is currently unimplemented")
|
||||
}
|
||||
}
|
||||
|
||||
return setupRegexps(config)
|
||||
}
|
||||
|
||||
// duplicationCheck returns an error if any application service configuration
|
||||
// entries that are supposed to be unique appear than once.
|
||||
func duplicationCheck(appservice ApplicationService, idMap, tokenMap, protocolMap *map[string]bool) error {
|
||||
// Check if we've already seen this ID. No two application services
|
||||
// can have the same ID or token.
|
||||
if (*idMap)[appservice.ID] {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Application service ID '%s' must be unique", appservice.ID,
|
||||
)})
|
||||
}
|
||||
// Check if we've already seen this token
|
||||
if (*tokenMap)[appservice.ASToken] {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Application service token '%s' must be unique", appservice.ASToken,
|
||||
)})
|
||||
}
|
||||
|
||||
// Add the id/token to their respective maps if we haven't already
|
||||
// seen them.
|
||||
(*idMap)[appservice.ID] = true
|
||||
(*tokenMap)[appservice.ASToken] = true
|
||||
|
||||
// Check if any application services are already handling this protocol
|
||||
for _, protocol := range appservice.Protocols {
|
||||
if (*protocolMap)[protocol] {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Application service protocol '%s' must be unique", protocol,
|
||||
)})
|
||||
}
|
||||
|
||||
// Add the protocol to the map of seen protocols
|
||||
(*protocolMap)[protocol] = true
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateNamespace returns nil or an error based on whether a given
|
||||
// application service namespace is valid. A namespace is valid if it has the
|
||||
// required fields, and its regex is correct.
|
||||
|
@ -287,12 +306,12 @@ func validateNamespace(
|
|||
// Check if GroupID for the users namespace is in the correct format
|
||||
if key == "users" && namespace.GroupID != "" {
|
||||
// TODO: Remove once group_id is implemented
|
||||
log.Warn("WARNING: Application service option group_id is currently unimplemented")
|
||||
log.Warn("WARNING: Application service option 'group_id' is currently unimplemented")
|
||||
|
||||
correctFormat := groupIDRegexp.MatchString(namespace.GroupID)
|
||||
if !correctFormat {
|
||||
return configErrors([]string{fmt.Sprintf(
|
||||
"Invalid user group_id field for application service %s.",
|
||||
"Invalid user 'group_id' field for application service %s.",
|
||||
appservice.ID,
|
||||
)})
|
||||
}
|
||||
|
|
|
@ -372,7 +372,7 @@ func (h *httpRoomserverQueryAPI) QueryMembershipForUser(
|
|||
defer span.Finish()
|
||||
|
||||
apiURL := h.roomserverURL + RoomserverQueryMembershipForUserPath
|
||||
return postJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
return commonHTTP.PostJSON(ctx, span, h.httpClient, apiURL, request, response)
|
||||
}
|
||||
|
||||
// QueryMembershipsForRoom implements RoomserverQueryAPI
|
||||
|
|
Loading…
Reference in a new issue