Merge remote-tracking branch 'origin/erikj/opentracing_sql' into anoa/opentracing

This commit is contained in:
Andrew Morgan 2018-07-23 16:18:48 +01:00
commit 8d0a9d7ccf
67 changed files with 1824 additions and 243 deletions

View file

@ -29,6 +29,7 @@ import (
"github.com/matrix-org/dendrite/appservice/workers"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/common/transactions"
@ -41,6 +42,7 @@ import (
// component.
func SetupAppServiceAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
accountsDB *accounts.Database,
deviceDB *devices.Database,
federation *gomatrixserverlib.FederationClient,
@ -48,8 +50,10 @@ func SetupAppServiceAPIComponent(
roomserverQueryAPI roomserverAPI.RoomserverQueryAPI,
transactionsCache *transactions.Cache,
) appserviceAPI.AppServiceQueryAPI {
tracer := tracers.SetupNewTracer("Dendrite - Appservice")
// Create a connection to the appservice postgres DB
appserviceDB, err := storage.NewDatabase(string(base.Cfg.Database.AppService))
appserviceDB, err := storage.NewDatabase(tracers, string(base.Cfg.Database.AppService))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to appservice db")
}
@ -90,6 +94,7 @@ func SetupAppServiceAPIComponent(
consumer := consumers.NewOutputRoomEventConsumer(
base.Cfg, base.KafkaConsumer, accountsDB, appserviceDB,
roomserverQueryAPI, roomserverAliasAPI, workerStates,
tracer,
)
if err := consumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start app service roomserver consumer")
@ -103,7 +108,7 @@ func SetupAppServiceAPIComponent(
// Set up HTTP Endpoints
routing.Setup(
base.APIMux, *base.Cfg, roomserverQueryAPI, roomserverAliasAPI,
accountsDB, federation, transactionsCache,
accountsDB, federation, transactionsCache, tracer,
)
return &appserviceQueryAPI

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
@ -38,6 +39,7 @@ type OutputRoomEventConsumer struct {
query api.RoomserverQueryAPI
alias api.RoomserverAliasAPI
serverName string
tracer opentracing.Tracer
workerStates []types.ApplicationServiceWorkerState
}
@ -51,6 +53,7 @@ func NewOutputRoomEventConsumer(
queryAPI api.RoomserverQueryAPI,
aliasAPI api.RoomserverAliasAPI,
workerStates []types.ApplicationServiceWorkerState,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
@ -64,6 +67,7 @@ func NewOutputRoomEventConsumer(
query: queryAPI,
alias: aliasAPI,
serverName: string(cfg.Matrix.ServerName),
tracer: tracer,
workerStates: workerStates,
}
consumer.ProcessMessage = s.onMessage

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixApp = "/_matrix/app/r0"
@ -37,11 +38,12 @@ func Setup(
accountDB *accounts.Database, // nolint: unparam
federation *gomatrixserverlib.FederationClient, // nolint: unparam
transactionsCache *transactions.Cache, // nolint: unparam
tracer opentracing.Tracer,
) {
appMux := apiMux.PathPrefix(pathPrefixApp).Subrouter()
appMux.Handle("/alias",
common.MakeExternalAPI("alias", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "alias", func(req *http.Request) util.JSONResponse {
// TODO: Implement
return util.JSONResponse{
Code: http.StatusOK,
@ -50,7 +52,7 @@ func Setup(
}),
).Methods(http.MethodGet, http.MethodOptions)
appMux.Handle("/user",
common.MakeExternalAPI("user", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "user", func(req *http.Request) util.JSONResponse {
// TODO: Implement
return util.JSONResponse{
Code: http.StatusOK,

View file

@ -20,6 +20,7 @@ import (
// Import postgres database driver
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,10 +32,10 @@ type Database struct {
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("postgres", dataSourceName); err != nil {
if result.db, err = common.OpenPostgresWithTracing(tracers, "appservice", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {

View file

@ -24,7 +24,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
@ -41,12 +40,14 @@ type Database struct {
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "accounts", dataSourceName); err != nil {
return nil, err
}
// TODO: Some files have prepare in a separate method such as in appservice.
// Some do not. We should be consistent.
partitions := common.PartitionOffsetStatements{}
if err = partitions.Prepare(db, "account"); err != nil {
return nil, err

View file

@ -35,10 +35,10 @@ type Database struct {
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "devices", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/consumers"
"github.com/matrix-org/dendrite/clientapi/producers"
"github.com/matrix-org/dendrite/clientapi/routing"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/transactions"
roomserverAPI "github.com/matrix-org/dendrite/roomserver/api"
@ -31,6 +32,7 @@ import (
// component.
func SetupClientAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
deviceDB *devices.Database,
accountsDB *accounts.Database,
federation *gomatrixserverlib.FederationClient,
@ -40,6 +42,8 @@ func SetupClientAPIComponent(
queryAPI roomserverAPI.RoomserverQueryAPI,
transactionsCache *transactions.Cache,
) {
tracer := tracers.SetupNewTracer("Dendrite - ClientAPI")
roomserverProducer := producers.NewRoomserverProducer(inputAPI)
userUpdateProducer := &producers.UserUpdateProducer{
@ -53,7 +57,7 @@ func SetupClientAPIComponent(
}
consumer := consumers.NewOutputRoomEventConsumer(
base.Cfg, base.KafkaConsumer, accountsDB, queryAPI,
base.Cfg, base.KafkaConsumer, accountsDB, queryAPI, tracer,
)
if err := consumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer")
@ -62,6 +66,6 @@ func SetupClientAPIComponent(
routing.Setup(
base.APIMux, *base.Cfg, roomserverProducer, queryAPI, aliasAPI,
accountsDB, deviceDB, federation, *keyRing, userUpdateProducer,
syncProducer, transactionsCache,
syncProducer, transactionsCache, tracer,
)
}

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
@ -34,6 +35,7 @@ type OutputRoomEventConsumer struct {
db *accounts.Database
query api.RoomserverQueryAPI
serverName string
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -42,6 +44,7 @@ func NewOutputRoomEventConsumer(
kafkaConsumer sarama.Consumer,
store *accounts.Database,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
@ -54,6 +57,7 @@ func NewOutputRoomEventConsumer(
db: store,
query: queryAPI,
serverName: string(cfg.Matrix.ServerName),
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -77,6 +81,9 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
if output.Type != api.OutputTypeNewRoomEvent {
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",
@ -96,7 +103,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return err
}
return s.db.UpdateMemberships(context.TODO(), events, output.NewRoomEvent.RemovesStateEventIDs)
return s.db.UpdateMemberships(ctx, events, output.NewRoomEvent.RemovesStateEventIDs)
}
// lookupStateEvents looks up the state events that are added by a new event.

View file

@ -32,6 +32,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixV1 = "/_matrix/client/api/v1"
@ -51,10 +52,10 @@ func Setup(
userUpdateProducer *producers.UserUpdateProducer,
syncProducer *producers.SyncAPIProducer,
transactionsCache *transactions.Cache,
tracer opentracing.Tracer,
) {
apiMux.Handle("/_matrix/client/versions",
common.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "versions", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{
Code: http.StatusOK,
JSON: struct {
@ -76,12 +77,12 @@ func Setup(
authData := auth.Data{accountDB, deviceDB, cfg.Derived.ApplicationServices}
r0mux.Handle("/createRoom",
common.MakeAuthAPI("createRoom", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "createRoom", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, producer, accountDB, aliasAPI)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/join/{roomIDOrAlias}",
common.MakeAuthAPI("join", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "join", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return JoinRoomByIDOrAlias(
req, device, vars["roomIDOrAlias"], cfg, federation, producer, queryAPI, aliasAPI, keyRing, accountDB,
@ -89,19 +90,19 @@ func Setup(
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}",
common.MakeAuthAPI("membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "membership", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, producer)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer, nil)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID,
@ -109,7 +110,7 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
emptyString := ""
eventType := vars["eventType"]
@ -121,54 +122,54 @@ func Setup(
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
common.MakeAuthAPI("send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, queryAPI, producer, nil)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
r0mux.Handle("/register", common.MakeExternalAPI(tracer, "register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg)
})).Methods(http.MethodPost, http.MethodOptions)
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
v1mux.Handle("/register", common.MakeExternalAPI(tracer, "register", func(req *http.Request) util.JSONResponse {
return LegacyRegister(req, accountDB, deviceDB, &cfg)
})).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
r0mux.Handle("/register/available", common.MakeExternalAPI(tracer, "registerAvailable", func(req *http.Request) util.JSONResponse {
return RegisterAvailable(req, accountDB)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return RemoveLocalAlias(req, device, vars["roomAlias"], aliasAPI)
}),
).Methods(http.MethodDelete, http.MethodOptions)
r0mux.Handle("/logout",
common.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Logout(req, deviceDB, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/logout/all",
common.MakeAuthAPI("logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "logout", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return LogoutAll(req, deviceDB, device)
}),
).Methods(http.MethodPost, http.MethodOptions)
@ -176,13 +177,13 @@ func Setup(
// Stub endpoints required by Riot
r0mux.Handle("/login",
common.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "login", func(req *http.Request) util.JSONResponse {
return Login(req, accountDB, deviceDB, cfg)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)
r0mux.Handle("/pushrules/",
common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API
res := json.RawMessage(`{
"global": {
@ -201,14 +202,14 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter",
common.MakeAuthAPI("put_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "put_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return PutFilter(req, device, accountDB, vars["userId"])
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/user/{userId}/filter/{filterId}",
common.MakeAuthAPI("get_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_filter", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
}),
@ -217,21 +218,21 @@ func Setup(
// Riot user settings
r0mux.Handle("/profile/{userID}",
common.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetProfile(req, accountDB, vars["userID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetAvatarURL(req, accountDB, vars["userID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeAuthAPI("profile_avatar_url", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "profile_avatar_url", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
}),
@ -240,14 +241,14 @@ func Setup(
// PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname",
common.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile_displayname", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetDisplayName(req, accountDB, vars["userID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/profile/{userID}/displayname",
common.MakeAuthAPI("profile_displayname", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "profile_displayname", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
}),
@ -256,32 +257,32 @@ func Setup(
// PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid",
common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/account/3pid",
common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
unstableMux.Handle("/account/3pid/delete",
common.MakeAuthAPI("account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Forget3PID(req, accountDB)
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
common.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg)
}),
).Methods(http.MethodPost, http.MethodOptions)
// Riot logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status",
common.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "presence", func(req *http.Request) util.JSONResponse {
// TODO: Set presence (probably the responsibility of a presence server not clientapi)
return util.JSONResponse{
Code: http.StatusOK,
@ -291,13 +292,13 @@ func Setup(
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/voip/turnServer",
common.MakeAuthAPI("turn_server", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "turn_server", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return RequestTurnServer(req, device, cfg)
}),
).Methods(http.MethodGet, http.MethodOptions)
unstableMux.Handle("/thirdparty/protocols",
common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols
return util.JSONResponse{
Code: http.StatusOK,
@ -307,7 +308,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/initialSync",
common.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms.
return util.JSONResponse{
Code: http.StatusForbidden,
@ -317,62 +318,62 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/user/{userID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "user_account_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/members",
common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], false, cfg, queryAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/joined_members",
common.MakeAuthAPI("rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "rooms_members", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], true, cfg, queryAPI)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/read_markers",
common.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_read_markers", func(req *http.Request) util.JSONResponse {
// TODO: return the read_markers.
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}),
).Methods(http.MethodPost, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
common.MakeExternalAPI("rooms_typing", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_typing", func(req *http.Request) util.JSONResponse {
// TODO: handling typing
return util.JSONResponse{Code: http.StatusOK, JSON: struct{}{}}
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/devices",
common.MakeAuthAPI("get_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_devices", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, deviceDB, device)
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("get_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_device", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetDeviceByID(req, deviceDB, device, vars["deviceID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("device_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "device_data", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
}),
@ -380,7 +381,7 @@ func Setup(
// Stub implementations for sytest
r0mux.Handle("/events",
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"chunk": []interface{}{},
"start": "",
@ -390,7 +391,7 @@ func Setup(
).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/initialSync",
common.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: http.StatusOK, JSON: map[string]interface{}{
"end": "",
}}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -63,7 +64,7 @@ func main() {
serverName := gomatrixserverlib.ServerName(*serverNameStr)
accountDB, err := accounts.NewDatabase(*database, serverName)
accountDB, err := accounts.NewDatabase(common.NoopTracers(), *database, serverName)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
@ -78,7 +79,7 @@ func main() {
os.Exit(1)
}
deviceDB, err := devices.NewDatabase(*database, serverName)
deviceDB, err := devices.NewDatabase(common.NoopTracers(), *database, serverName)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)

View file

@ -16,13 +16,18 @@ package main
import (
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/transactions"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "AppServiceAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "AppServiceAPI")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
@ -32,7 +37,7 @@ func main() {
cache := transactions.New()
appservice.SetupAppServiceAPIComponent(
base, accountDB, deviceDB, federation, alias, query, cache,
base, tracers, accountDB, deviceDB, federation, alias, query, cache,
)
base.SetupAndServeHTTP(string(base.Cfg.Listen.FederationSender))

View file

@ -16,6 +16,7 @@ package main
import (
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions"
@ -24,7 +25,10 @@ import (
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "ClientAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "ClientAPI")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
@ -37,7 +41,7 @@ func main() {
cache := transactions.New()
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB, federation, &keyRing,
base, tracers, deviceDB, accountDB, federation, &keyRing,
alias, input, query, cache,
)

View file

@ -15,6 +15,7 @@
package main
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/federationapi"
@ -22,7 +23,11 @@ import (
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "FederationAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "FederationAPI")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
@ -34,7 +39,7 @@ func main() {
alias, input, query := base.CreateHTTPRoomserverAPIs()
federationapi.SetupFederationAPIComponent(
base, accountDB, deviceDB, federation, &keyRing,
base, tracers, accountDB, deviceDB, federation, &keyRing,
alias, input, query,
)

View file

@ -15,21 +15,33 @@
package main
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/federationsender"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "FederationSender")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "FederationSender")
defer base.Close() // nolint: errcheck
federation := base.CreateFederationClient()
/* TODO delete
err = tracers.InitGlobalTracer("Dendrite - Federation Sender")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
*/
_, _, query := base.CreateHTTPRoomserverAPIs()
federationsender.SetupFederationSenderComponent(
base, federation, query,
base, tracers, federation, query,
)
base.SetupAndServeHTTP(string(base.Cfg.Listen.FederationSender))

View file

@ -15,18 +15,23 @@
package main
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/mediaapi"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "MediaAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "MediaAPI")
defer base.Close() // nolint: errcheck
deviceDB := base.CreateDeviceDB()
mediaapi.SetupMediaAPIComponent(base, deviceDB)
mediaapi.SetupMediaAPIComponent(base, tracers, deviceDB)
base.SetupAndServeHTTP(string(base.Cfg.Listen.MediaAPI))
}

View file

@ -18,13 +18,12 @@ import (
"flag"
"net/http"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/appservice"
"github.com/matrix-org/dendrite/clientapi"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/common/keydb"
"github.com/matrix-org/dendrite/common/transactions"
"github.com/matrix-org/dendrite/federationapi"
"github.com/matrix-org/dendrite/federationsender"
"github.com/matrix-org/dendrite/mediaapi"
@ -45,7 +44,11 @@ var (
func main() {
cfg := basecomponent.ParseMonolithFlags()
base := basecomponent.NewBaseDendrite(cfg, "Monolith")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "Monolith")
defer base.Close() // nolint: errcheck
accountDB := base.CreateAccountsDB()
@ -54,19 +57,19 @@ func main() {
federation := base.CreateFederationClient()
keyRing := keydb.CreateKeyRing(federation.Client, keyDB)
alias, input, query := roomserver.SetupRoomServerComponent(base)
alias, input, query := roomserver.SetupRoomServerComponent(base, tracers)
clientapi.SetupClientAPIComponent(
base, deviceDB, accountDB,
base, tracers, deviceDB, accountDB,
federation, &keyRing, alias, input, query,
transactions.New(),
)
federationapi.SetupFederationAPIComponent(base, accountDB, deviceDB, federation, &keyRing, alias, input, query)
federationsender.SetupFederationSenderComponent(base, federation, query)
mediaapi.SetupMediaAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB)
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query)
appservice.SetupAppServiceAPIComponent(base, accountDB, deviceDB, federation, alias, query, transactions.New())
federationapi.SetupFederationAPIComponent(base, tracers, accountDB, deviceDB, federation, &keyRing, alias, input, query)
federationsender.SetupFederationSenderComponent(base, tracers, federation, query)
mediaapi.SetupMediaAPIComponent(base, tracers, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, tracers, deviceDB)
syncapi.SetupSyncAPIComponent(base, tracers, deviceDB, accountDB, query)
appservice.SetupAppServiceAPIComponent(base, tracers, accountDB, deviceDB, federation, alias, query, transactions.New())
httpHandler := common.WrapHandlerInCORS(base.APIMux)

View file

@ -15,18 +15,23 @@
package main
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/publicroomsapi"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "PublicRoomsAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "PublicRoomsAPI")
defer base.Close() // nolint: errcheck
deviceDB := base.CreateDeviceDB()
publicroomsapi.SetupPublicRoomsAPIComponent(base, deviceDB)
publicroomsapi.SetupPublicRoomsAPIComponent(base, tracers, deviceDB)
base.SetupAndServeHTTP(string(base.Cfg.Listen.PublicRoomsAPI))
}

View file

@ -17,16 +17,21 @@ package main
import (
_ "net/http/pprof"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/roomserver"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "RoomServerAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "RoomServerAPI")
defer base.Close() // nolint: errcheck
roomserver.SetupRoomServerComponent(base)
roomserver.SetupRoomServerComponent(base, tracers)
base.SetupAndServeHTTP(string(base.Cfg.Listen.RoomServer))
}

View file

@ -15,13 +15,18 @@
package main
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/syncapi"
)
func main() {
cfg := basecomponent.ParseFlags()
base := basecomponent.NewBaseDendrite(cfg, "SyncAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
base := basecomponent.NewBaseDendrite(cfg, tracers, "SyncAPI")
defer base.Close() // nolint: errcheck
deviceDB := base.CreateDeviceDB()
@ -29,7 +34,7 @@ func main() {
_, _, query := base.CreateHTTPRoomserverAPIs()
syncapi.SetupSyncAPIComponent(base, deviceDB, accountDB, query)
syncapi.SetupSyncAPIComponent(base, tracers, deviceDB, accountDB, query)
base.SetupAndServeHTTP(string(base.Cfg.Listen.SyncAPI))
}

View file

@ -16,7 +16,6 @@ package basecomponent
import (
"database/sql"
"io"
"net/http"
"github.com/matrix-org/dendrite/common/keydb"
@ -43,7 +42,7 @@ import (
// Must be closed when shutting down.
type BaseDendrite struct {
componentName string
tracerCloser io.Closer
tracers *common.Tracers
// APIMux should be used to register new public matrix api endpoints
APIMux *mux.Router
@ -55,20 +54,19 @@ type BaseDendrite struct {
// NewBaseDendrite creates a new instance to be used by a component.
// The componentName is used for logging purposes, and should be a friendly name
// of the compontent running, e.g. "SyncAPI"
func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
func NewBaseDendrite(
cfg *config.Dendrite,
tracers *common.Tracers,
componentName string,
) *BaseDendrite {
common.SetupStdLogging()
common.SetupHookLogging(cfg.Logging, componentName)
closer, err := cfg.SetupTracing("Dendrite" + componentName)
if err != nil {
logrus.WithError(err).Panicf("failed to start opentracing")
}
kafkaConsumer, kafkaProducer := setupKafka(cfg)
return &BaseDendrite{
componentName: componentName,
tracerCloser: closer,
tracers: tracers,
Cfg: cfg,
APIMux: mux.NewRouter(),
KafkaConsumer: kafkaConsumer,
@ -78,7 +76,7 @@ func NewBaseDendrite(cfg *config.Dendrite, componentName string) *BaseDendrite {
// Close implements io.Closer
func (b *BaseDendrite) Close() error {
return b.tracerCloser.Close()
return b.tracers.Close()
}
// CreateHTTPAppServiceAPIs returns the QueryAPI for hitting the appservice
@ -103,7 +101,7 @@ func (b *BaseDendrite) CreateHTTPRoomserverAPIs() (
// CreateDeviceDB creates a new instance of the device database. Should only be
// called once per component.
func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
db, err := devices.NewDatabase(string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
db, err := devices.NewDatabase(b.tracers, string(b.Cfg.Database.Device), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to devices db")
}
@ -114,7 +112,7 @@ func (b *BaseDendrite) CreateDeviceDB() *devices.Database {
// CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component.
func (b *BaseDendrite) CreateAccountsDB() *accounts.Database {
db, err := accounts.NewDatabase(string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
db, err := accounts.NewDatabase(b.tracers, string(b.Cfg.Database.Account), b.Cfg.Matrix.ServerName)
if err != nil {
logrus.WithError(err).Panicf("failed to connect to accounts db")
}
@ -125,7 +123,7 @@ func (b *BaseDendrite) CreateAccountsDB() *accounts.Database {
// CreateKeyDB creates a new instance of the key database. Should only be called
// once per component.
func (b *BaseDendrite) CreateKeyDB() *keydb.Database {
db, err := keydb.NewDatabase(string(b.Cfg.Database.ServerKey))
db, err := keydb.NewDatabase(b.tracers, string(b.Cfg.Database.ServerKey))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to keys db")
}

View file

@ -19,7 +19,6 @@ import (
"crypto/sha256"
"encoding/pem"
"fmt"
"io"
"io/ioutil"
"path/filepath"
"regexp"
@ -28,12 +27,10 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ed25519"
yaml "gopkg.in/yaml.v2"
jaegerconfig "github.com/uber/jaeger-client-go/config"
jaegermetrics "github.com/uber/jaeger-lib/metrics"
)
// Version is the current version of the config format.
@ -658,25 +655,3 @@ func (config *Dendrite) RoomServerURL() string {
// internet for an internal API.
return "http://" + string(config.Listen.RoomServer)
}
// SetupTracing configures the opentracing using the supplied configuration.
func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) {
return config.Tracing.Jaeger.InitGlobalTracer(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
}
// logrusLogger is a small wrapper that implements jaeger.Logger using logrus.
type logrusLogger struct {
l *logrus.Logger
}
func (l logrusLogger) Error(msg string) {
l.l.Error(msg)
}
func (l logrusLogger) Infof(msg string, args ...interface{}) {
l.l.Infof(msg, args...)
}

View file

@ -15,7 +15,7 @@ import (
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which authenticates the request.
func MakeAuthAPI(
metricsName string, data auth.Data,
tracer opentracing.Tracer, metricsName string, data auth.Data,
f func(*http.Request, *authtypes.Device) util.JSONResponse,
) http.Handler {
h := func(req *http.Request) util.JSONResponse {
@ -26,15 +26,15 @@ func MakeAuthAPI(
return f(req, device)
}
return MakeExternalAPI(metricsName, h)
return MakeExternalAPI(tracer, metricsName, h)
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
func MakeExternalAPI(tracer opentracing.Tracer, metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
h := util.MakeJSONAPI(util.NewJSONRequestHandler(f))
withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName)
span := tracer.StartSpan(metricsName)
defer span.Finish()
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
h.ServeHTTP(w, req)
@ -47,11 +47,10 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// This is used for APIs that are internal to dendrite.
// If we are passed a tracing context in the request headers then we use that
// as the parent of any tracing spans we create.
func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
func MakeInternalAPI(tracer opentracing.Tracer, metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
h := util.MakeJSONAPI(util.NewJSONRequestHandler(f))
withSpan := func(w http.ResponseWriter, req *http.Request) {
carrier := opentracing.HTTPHeadersCarrier(req.Header)
tracer := opentracing.GlobalTracer()
clientContext, err := tracer.Extract(opentracing.HTTPHeaders, carrier)
var span opentracing.Span
if err == nil {
@ -71,6 +70,7 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.
func MakeFedAPI(
tracer opentracing.Tracer,
metricsName string,
serverName gomatrixserverlib.ServerName,
keyRing gomatrixserverlib.KeyRing,
@ -85,7 +85,7 @@ func MakeFedAPI(
}
return f(req, fedReq)
}
return MakeExternalAPI(metricsName, h)
return MakeExternalAPI(tracer, metricsName, h)
}
// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics
@ -105,7 +105,7 @@ func WrapHandlerInCORS(h http.Handler) http.HandlerFunc {
w.Header().Set("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, Authorization")
if r.Method == http.MethodOptions && r.Header.Get("Access-Control-Request-Method") != "" {
// Its easiest just to always return a 200 OK for everything. Whether
// It's easiest just to always return a 200 OK for everything. Whether
// this is technically correct or not is a question, but in the end this
// is what a lot of other people do (including synapse) and the clients
// are perfectly happy with it.

View file

@ -16,8 +16,8 @@ package keydb
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,8 +31,8 @@ type Database struct {
// It creates the necessary tables if they don't already exist.
// It prepares all the SQL statements that it will use.
// Returns an error if there was a problem talking to the database.
func NewDatabase(dataSourceName string) (*Database, error) {
db, err := sql.Open("postgres", dataSourceName)
func NewDatabase(tracers *common.Tracers, dataSourceName string) (*Database, error) {
db, err := common.OpenPostgresWithTracing(tracers, "keys", dataSourceName)
if err != nil {
return nil, err
}

View file

@ -16,8 +16,12 @@ package common
import (
"database/sql"
"fmt"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/lib/pq"
"github.com/matrix-org/util"
)
// A Transaction is something that can be committed or rolledback.
@ -74,3 +78,18 @@ func IsUniqueConstraintViolationErr(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
// OpenPostgresWithTracing creates a new DB instance where calls will be
// traced with the given tracer
func OpenPostgresWithTracing(tracers *Tracers, databaseName, connstr string) (*sql.DB, error) {
tracer := tracers.SetupNewTracer("sql: " + databaseName)
hooks := othooks.New(tracer)
// This is a hack to get around the fact that you can't directly open
// a sql.DB with a given driver, you *have* to register it.
registrationName := fmt.Sprintf("postgres-ot-%s", util.RandomString(5))
sql.Register(registrationName, sqlhooks.Wrap(&pq.Driver{}, hooks))
return sql.Open(registrationName, connstr)
}

View file

@ -0,0 +1,103 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"io"
"github.com/matrix-org/dendrite/common/config"
opentracing "github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus"
jaegerconfig "github.com/uber/jaeger-client-go/config"
jaegermetrics "github.com/uber/jaeger-lib/metrics"
)
type Tracers struct {
cfg *config.Dendrite
closers []io.Closer
}
func NoopTracers() *Tracers {
return &Tracers{}
}
func NewTracers(cfg *config.Dendrite) *Tracers {
return &Tracers{
cfg: cfg,
}
}
func (t *Tracers) InitGlobalTracer(serviceName string) error {
if t.cfg == nil {
return nil
}
// Set up GlobalTracer
closer, err := t.cfg.Tracing.Jaeger.InitGlobalTracer(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
if err != nil {
return err
}
t.closers = append(t.closers, closer)
return nil
}
// SetupTracing configures the opentracing using the supplied configuration.
func (t *Tracers) SetupNewTracer(serviceName string) opentracing.Tracer {
if t.cfg == nil {
return opentracing.NoopTracer{}
}
tracer, closer, err := t.cfg.Tracing.Jaeger.New(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
if err != nil {
logrus.Panicf("Failed to create new tracer %s: %s", serviceName, err)
}
t.closers = append(t.closers, closer)
return tracer
}
func (t *Tracers) Close() error {
for _, c := range t.closers {
c.Close() // nolint: errcheck
}
return nil
}
// logrusLogger is a small wrapper that implements jaeger.Logger using logrus.
type logrusLogger struct {
l *logrus.Logger
}
func (l logrusLogger) Error(msg string) {
l.l.Error(msg)
}
func (l logrusLogger) Infof(msg string, args ...interface{}) {
l.l.Infof(msg, args...)
}

View file

@ -17,6 +17,7 @@ package federationapi
import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/roomserver/api"
// TODO: Are we really wanting to pull in the producer from clientapi
@ -29,6 +30,7 @@ import (
// FederationAPI component.
func SetupFederationAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
accountsDB *accounts.Database,
deviceDB *devices.Database,
federation *gomatrixserverlib.FederationClient,
@ -39,8 +41,11 @@ func SetupFederationAPIComponent(
) {
roomserverProducer := producers.NewRoomserverProducer(inputAPI)
tracer := tracers.SetupNewTracer("Dendrite - FederationAPI")
routing.Setup(
base.APIMux, *base.Cfg, queryAPI, aliasAPI,
roomserverProducer, *keyRing, federation, accountsDB, deviceDB,
tracer,
)
}

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const (
@ -45,11 +46,12 @@ func Setup(
federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database,
deviceDB *devices.Database,
tracer opentracing.Tracer,
) {
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
localKeys := common.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
localKeys := common.MakeExternalAPI(tracer, "localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg)
})
@ -61,7 +63,7 @@ func Setup(
v2keysmux.Handle("/server/", localKeys).Methods(http.MethodGet)
v1fedmux.Handle("/send/{txnID}/", common.MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys,
tracer, "federation_send", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return Send(
@ -72,7 +74,7 @@ func Setup(
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/invite/{roomID}/{eventID}", common.MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys,
tracer, "federation_invite", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return Invite(
@ -82,14 +84,15 @@ func Setup(
},
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/3pid/onbind", common.MakeExternalAPI("3pid_onbind",
v1fedmux.Handle("/3pid/onbind", common.MakeExternalAPI(
tracer, "3pid_onbind",
func(req *http.Request) util.JSONResponse {
return CreateInvitesFrom3PIDInvites(req, query, cfg, producer, federation, accountDB)
},
)).Methods(http.MethodPost, http.MethodOptions)
v1fedmux.Handle("/exchange_third_party_invite/{roomID}", common.MakeFedAPI(
"exchange_third_party_invite", cfg.Matrix.ServerName, keys,
tracer, "exchange_third_party_invite", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return ExchangeThirdPartyInvite(
@ -99,7 +102,7 @@ func Setup(
)).Methods(http.MethodPut, http.MethodOptions)
v1fedmux.Handle("/event/{eventID}", common.MakeFedAPI(
"federation_get_event", cfg.Matrix.ServerName, keys,
tracer, "federation_get_event", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetEvent(
@ -109,7 +112,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/state/{roomID}", common.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys,
tracer, "federation_get_event_auth", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetState(
@ -120,7 +123,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/state_ids/{roomID}", common.MakeFedAPI(
"federation_get_event_auth", cfg.Matrix.ServerName, keys,
tracer, "federation_get_event_auth", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetStateIDs(
@ -131,7 +134,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/query/directory/", common.MakeFedAPI(
"federation_query_room_alias", cfg.Matrix.ServerName, keys,
tracer, "federation_query_room_alias", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
return RoomAliasToID(
httpReq, federation, cfg, aliasAPI,
@ -140,7 +143,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/query/profile", common.MakeFedAPI(
"federation_query_profile", cfg.Matrix.ServerName, keys,
tracer, "federation_query_profile", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
return GetProfile(
httpReq, accountDB, cfg,
@ -149,7 +152,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/query/user_devices/{userID}", common.MakeFedAPI(
"federation_query_user_devices", cfg.Matrix.ServerName, keys,
tracer, "federation_query_user_devices", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetUserDevices(
@ -159,7 +162,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/make_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_make_join", cfg.Matrix.ServerName, keys,
tracer, "federation_make_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -171,7 +174,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/send_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys,
tracer, "federation_send_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -183,7 +186,7 @@ func Setup(
)).Methods(http.MethodPut)
v1fedmux.Handle("/make_leave/{roomID}/{userID}", common.MakeFedAPI(
"federation_make_leave", cfg.Matrix.ServerName, keys,
tracer, "federation_make_leave", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -195,7 +198,7 @@ func Setup(
)).Methods(http.MethodGet)
v1fedmux.Handle("/send_leave/{roomID}/{userID}", common.MakeFedAPI(
"federation_send_leave", cfg.Matrix.ServerName, keys,
tracer, "federation_send_leave", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -207,14 +210,14 @@ func Setup(
)).Methods(http.MethodPut)
v1fedmux.Handle("/version", common.MakeExternalAPI(
"federation_version",
tracer, "federation_version",
func(httpReq *http.Request) util.JSONResponse {
return Version()
},
)).Methods(http.MethodGet)
v1fedmux.Handle("get_missing_events/{roomID}", common.MakeFedAPI(
"federation_get_missing_events", cfg.Matrix.ServerName, keys,
tracer, "federation_get_missing_events", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetMissingEvents(httpReq, request, query, vars["roomID"])

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -36,6 +37,7 @@ type OutputRoomEventConsumer struct {
db *storage.Database
queues *queue.OutgoingQueues
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -45,6 +47,7 @@ func NewOutputRoomEventConsumer(
queues *queue.OutgoingQueues,
store *storage.Database,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
@ -56,6 +59,7 @@ func NewOutputRoomEventConsumer(
db: store,
queues: queues,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -85,6 +89,10 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
)
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
ev := &output.NewRoomEvent.Event
log.WithFields(log.Fields{
"event_id": ev.EventID(),
@ -92,7 +100,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
"send_as_server": output.NewRoomEvent.SendAsServer,
}).Info("received event from roomserver")
if err := s.processMessage(*output.NewRoomEvent); err != nil {
if err := s.processMessage(ctx, *output.NewRoomEvent); err != nil {
// panic rather than continue with an inconsistent database
log.WithFields(log.Fields{
"event": string(ev.JSON()),
@ -108,8 +116,10 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
// processMessage updates the list of currently joined hosts in the room
// and then sends the event to the hosts that were joined before the event.
func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) error {
addsStateEvents, err := s.lookupStateEvents(ore.AddsStateEventIDs, ore.Event)
func (s *OutputRoomEventConsumer) processMessage(
ctx context.Context, ore api.OutputNewRoomEvent,
) error {
addsStateEvents, err := s.lookupStateEvents(ctx, ore.AddsStateEventIDs, ore.Event)
if err != nil {
return err
}
@ -123,7 +133,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
// TODO(#290): handle EventIDMismatchError and recover the current state by
// talking to the roomserver
oldJoinedHosts, err := s.db.UpdateRoom(
context.TODO(),
ctx,
ore.Event.RoomID(),
ore.LastSentEventID,
ore.Event.EventID(),
@ -148,7 +158,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
}
// Work out which hosts were joined at the event itself.
joinedHostsAtEvent, err := s.joinedHostsAtEvent(ore, oldJoinedHosts)
joinedHostsAtEvent, err := s.joinedHostsAtEvent(ctx, ore, oldJoinedHosts)
if err != nil {
return err
}
@ -169,7 +179,7 @@ func (s *OutputRoomEventConsumer) processMessage(ore api.OutputNewRoomEvent) err
// events from the room server.
// Returns an error if there was a problem talking to the room server.
func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost,
ctx context.Context, ore api.OutputNewRoomEvent, oldJoinedHosts []types.JoinedHost,
) ([]gomatrixserverlib.ServerName, error) {
// Combine the delta into a single delta so that the adds and removes can
// cancel each other out. This should reduce the number of times we need
@ -178,7 +188,7 @@ func (s *OutputRoomEventConsumer) joinedHostsAtEvent(
ore.AddsStateEventIDs, ore.RemovesStateEventIDs,
ore.StateBeforeAddsEventIDs, ore.StateBeforeRemovesEventIDs,
)
combinedAddsEvents, err := s.lookupStateEvents(combinedAdds, ore.Event)
combinedAddsEvents, err := s.lookupStateEvents(ctx, combinedAdds, ore.Event)
if err != nil {
return nil, err
}
@ -288,7 +298,7 @@ func combineDeltas(adds1, removes1, adds2, removes2 []string) (adds, removes []s
// lookupStateEvents looks up the state events that are added by a new event.
func (s *OutputRoomEventConsumer) lookupStateEvents(
addsStateEventIDs []string, event gomatrixserverlib.Event,
ctx context.Context, addsStateEventIDs []string, event gomatrixserverlib.Event,
) ([]gomatrixserverlib.Event, error) {
// Fast path if there aren't any new state events.
if len(addsStateEventIDs) == 0 {
@ -321,7 +331,7 @@ func (s *OutputRoomEventConsumer) lookupStateEvents(
// from the roomserver using the query API.
eventReq := api.QueryEventsByIDRequest{EventIDs: missing}
var eventResp api.QueryEventsByIDResponse
if err := s.query.QueryEventsByID(context.TODO(), &eventReq, &eventResp); err != nil {
if err := s.query.QueryEventsByID(ctx, &eventReq, &eventResp); err != nil {
return nil, err
}

View file

@ -15,6 +15,7 @@
package federationsender
import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/federationsender/consumers"
"github.com/matrix-org/dendrite/federationsender/queue"
@ -28,19 +29,22 @@ import (
// FederationSender component.
func SetupFederationSenderComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
federation *gomatrixserverlib.FederationClient,
queryAPI api.RoomserverQueryAPI,
) {
federationSenderDB, err := storage.NewDatabase(string(base.Cfg.Database.FederationSender))
federationSenderDB, err := storage.NewDatabase(tracers, string(base.Cfg.Database.FederationSender))
if err != nil {
logrus.WithError(err).Panic("failed to connect to federation sender db")
}
queues := queue.NewOutgoingQueues(base.Cfg.Matrix.ServerName, federation)
tracer := tracers.SetupNewTracer("Dendrite - FederationSender")
consumer := consumers.NewOutputRoomEventConsumer(
base.Cfg, base.KafkaConsumer, queues,
federationSenderDB, queryAPI,
federationSenderDB, queryAPI, tracer,
)
if err = consumer.Start(); err != nil {
logrus.WithError(err).Panic("failed to start room server consumer")

View file

@ -31,10 +31,10 @@ type Database struct {
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("postgres", dataSourceName); err != nil {
if result.db, err = common.OpenPostgresWithTracing(tracers, "federationsender", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {

View file

@ -16,6 +16,7 @@ package mediaapi
import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
@ -27,14 +28,18 @@ import (
// component.
func SetupMediaAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
deviceDB *devices.Database,
) {
mediaDB, err := storage.Open(string(base.Cfg.Database.MediaAPI))
tracer := tracers.SetupNewTracer("Dendrite - MediaAPI")
mediaDB, err := storage.Open(tracers, string(base.Cfg.Database.MediaAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to media db")
}
routing.Setup(
base.APIMux, base.Cfg, mediaDB, deviceDB, gomatrixserverlib.NewClient(),
tracer,
)
}

View file

@ -19,6 +19,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
opentracing "github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
@ -40,6 +41,7 @@ func Setup(
db *storage.Database,
deviceDB *devices.Database,
client *gomatrixserverlib.Client,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
@ -48,9 +50,11 @@ func Setup(
}
authData := auth.Data{nil, deviceDB, nil}
// TODO: How to use tracing with these endpoints?
// TODO: Add AS support
r0mux.Handle("/upload", common.MakeAuthAPI(
"upload", authData,
tracer, "upload", authData,
func(req *http.Request, _ *authtypes.Device) util.JSONResponse {
return Upload(req, cfg, db, activeThumbnailGeneration)
},
@ -75,6 +79,8 @@ func makeDownloadAPI(
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) http.HandlerFunc {
// TODO: Add opentracing.
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)

View file

@ -18,8 +18,7 @@ import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,10 +30,10 @@ type Database struct {
}
// Open opens a postgres database.
func Open(dataSourceName string) (*Database, error) {
func Open(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var d Database
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "media", dataSourceName); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/dendrite/roomserver/api"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -31,6 +32,7 @@ type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
db *storage.PublicRoomsServerDatabase
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -39,6 +41,7 @@ func NewOutputRoomEventConsumer(
kafkaConsumer sarama.Consumer,
store *storage.PublicRoomsServerDatabase,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
@ -49,6 +52,7 @@ func NewOutputRoomEventConsumer(
roomServerConsumer: &consumer,
db: store,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -77,6 +81,9 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
ev := output.NewRoomEvent.Event
log.WithFields(log.Fields{
"event_id": ev.EventID(),
@ -86,17 +93,17 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
addQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.AddsStateEventIDs}
var addQueryRes api.QueryEventsByIDResponse
if err := s.query.QueryEventsByID(context.TODO(), &addQueryReq, &addQueryRes); err != nil {
if err := s.query.QueryEventsByID(ctx, &addQueryReq, &addQueryRes); err != nil {
log.Warn(err)
return err
}
remQueryReq := api.QueryEventsByIDRequest{EventIDs: output.NewRoomEvent.RemovesStateEventIDs}
var remQueryRes api.QueryEventsByIDResponse
if err := s.query.QueryEventsByID(context.TODO(), &remQueryReq, &remQueryRes); err != nil {
if err := s.query.QueryEventsByID(ctx, &remQueryReq, &remQueryRes); err != nil {
log.Warn(err)
return err
}
return s.db.UpdateRoomFromEvents(context.TODO(), addQueryRes.Events, remQueryRes.Events)
return s.db.UpdateRoomFromEvents(ctx, addQueryRes.Events, remQueryRes.Events)
}

View file

@ -16,6 +16,7 @@ package publicroomsapi
import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/publicroomsapi/routing"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
@ -26,12 +27,15 @@ import (
// component.
func SetupPublicRoomsAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
deviceDB *devices.Database,
) {
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(string(base.Cfg.Database.PublicRoomsAPI))
publicRoomsDB, err := storage.NewPublicRoomsServerDatabase(tracers, string(base.Cfg.Database.PublicRoomsAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to public rooms db")
}
routing.Setup(base.APIMux, deviceDB, publicRoomsDB)
tracer := tracers.SetupNewTracer("Dendrite - PublicRoomsAPI")
routing.Setup(base.APIMux, deviceDB, publicRoomsDB, tracer)
}

View file

@ -25,31 +25,37 @@ import (
"github.com/matrix-org/dendrite/publicroomsapi/directory"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixR0 = "/_matrix/client/r0"
// Setup configures the given mux with publicroomsapi server listeners
func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storage.PublicRoomsServerDatabase) {
func Setup(
apiMux *mux.Router,
deviceDB *devices.Database,
publicRoomsDB *storage.PublicRoomsServerDatabase,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{nil, deviceDB, nil}
r0mux.Handle("/directory/list/room/{roomID}",
common.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "directory_list", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return directory.GetVisibility(req, publicRoomsDB, vars["roomID"])
}),
).Methods(http.MethodGet, http.MethodOptions)
// TODO: Add AS support
r0mux.Handle("/directory/list/room/{roomID}",
common.MakeAuthAPI("directory_list", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_list", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return directory.SetVisibility(req, publicRoomsDB, vars["roomID"])
}),
).Methods(http.MethodPut, http.MethodOptions)
r0mux.Handle("/publicRooms",
common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "public_rooms", func(req *http.Request) util.JSONResponse {
return directory.GetPublicRooms(req, publicRoomsDB)
}),
).Methods(http.MethodGet, http.MethodPost, http.MethodOptions)

View file

@ -35,10 +35,10 @@ type PublicRoomsServerDatabase struct {
type attributeValue interface{}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
func NewPublicRoomsServerDatabase(tracers *common.Tracers, dataSourceName string) (*PublicRoomsServerDatabase, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "publicrooms", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
// RoomserverAliasAPIDatabase has the storage APIs needed to implement the alias API.
@ -156,7 +157,7 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
StateKey: &serverName,
}
// Retrieve the updated list of aliases, marhal it and set it as the
// Retrieve the updated list of aliases, marshal it and set it as the
// event's content
aliases, err := r.DB.GetAliasesForRoomID(ctx, roomID)
if err != nil {
@ -229,10 +230,10 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
}
// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux.
func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(
api.RoomserverSetRoomAliasPath,
common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "setRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -246,7 +247,7 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverGetRoomIDForAliasPath,
common.MakeInternalAPI("GetRoomIDForAlias", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "getRoomIDForAlias", func(req *http.Request) util.JSONResponse {
var request api.GetRoomIDForAliasRequest
var response api.GetRoomIDForAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -260,7 +261,7 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverRemoveRoomAliasPath,
common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "removeRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {

View file

@ -15,7 +15,11 @@
package api
import (
"context"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
// An OutputType is a type of roomserver output.
@ -41,6 +45,44 @@ type OutputEvent struct {
NewInviteEvent *OutputNewInviteEvent `json:"new_invite_event,omitempty"`
// The content of event with type OutputTypeRetireInviteEvent
RetireInviteEvent *OutputRetireInviteEvent `json:"retire_invite_event,omitempty"`
// Serialized span context
OpentracingCarrier opentracing.TextMapCarrier `json:"opentracing_carrier"`
}
// AddSpanFromContext fills out the OpentracingCarrier field from the given context
func (o *OutputEvent) AddSpanFromContext(ctx context.Context) error {
span := opentracing.SpanFromContext(ctx)
ext.SpanKindProducer.Set(span)
carrier := make(opentracing.TextMapCarrier)
tracer := opentracing.GlobalTracer()
err := tracer.Inject(span.Context(), opentracing.TextMap, carrier)
if err != nil {
return err
}
o.OpentracingCarrier = carrier
return nil
}
// StartSpanAndReplaceContext produces a context and opentracing span from the
// info embedded in OutputEvent
func (o *OutputEvent) StartSpanAndReplaceContext(
ctx context.Context, tracer opentracing.Tracer,
) (context.Context, opentracing.Span) {
producerContext, err := tracer.Extract(opentracing.TextMap, o.OpentracingCarrier)
var span opentracing.Span
if err != nil {
// Default to a span without reference to producer context.
span = tracer.StartSpan("output_event_consumer")
} else {
// Set the producer context.
span = tracer.StartSpan("output_event_consumer", opentracing.FollowsFrom(producerContext))
}
return opentracing.ContextWithSpan(ctx, span), span
}
// An OutputNewRoomEvent is written when the roomserver receives a new event.

View file

@ -221,7 +221,7 @@ func processInviteEvent(
return nil
}
outputUpdates, err := updateToInviteMembership(updater, &input.Event, nil)
outputUpdates, err := updateToInviteMembership(ctx, updater, &input.Event, nil)
if err != nil {
return err
}

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -78,9 +79,9 @@ func (r *RoomserverInputAPI) InputRoomEvents(
}
// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux.
func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(api.RoomserverInputRoomEventsPath,
common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -93,3 +94,29 @@ func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
}),
)
}
type InProcessRoomServerInput struct {
db RoomserverInputAPI
tracer opentracing.Tracer
}
func NewInProcessRoomServerInput(db RoomserverInputAPI, tracer opentracing.Tracer) *InProcessRoomServerInput {
return &InProcessRoomServerInput{
db: db, tracer: tracer,
}
}
func (r *InProcessRoomServerInput) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) error {
span := r.tracer.StartSpan(
"InputRoomEvents",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return r.db.InputRoomEvents(ctx, request, response)
}

View file

@ -280,10 +280,17 @@ func (u *latestEventsUpdater) makeOutputNewRoomEvent() (*api.OutputEvent, error)
}
ore.SendAsServer = u.sendAsServer
return &api.OutputEvent{
oe := api.OutputEvent{
Type: api.OutputTypeNewRoomEvent,
NewRoomEvent: &ore,
}, nil
}
err = oe.AddSpanFromContext(u.ctx)
if err != nil {
return nil, err
}
return &oe, nil
}
type eventNIDSorter []types.EventNID

View file

@ -77,7 +77,7 @@ func updateMemberships(
ae = &ev.Event
}
}
if updates, err = updateMembership(updater, targetUserNID, re, ae, updates); err != nil {
if updates, err = updateMembership(ctx, updater, targetUserNID, re, ae, updates); err != nil {
return nil, err
}
}
@ -85,6 +85,7 @@ func updateMemberships(
}
func updateMembership(
ctx context.Context,
updater types.RoomRecentEventsUpdater, targetUserNID types.EventStateKeyNID,
remove, add *gomatrixserverlib.Event,
updates []api.OutputEvent,
@ -119,11 +120,11 @@ func updateMembership(
switch newMembership {
case invite:
return updateToInviteMembership(mu, add, updates)
return updateToInviteMembership(ctx, mu, add, updates)
case join:
return updateToJoinMembership(mu, add, updates)
return updateToJoinMembership(ctx, mu, add, updates)
case leave, ban:
return updateToLeaveMembership(mu, add, newMembership, updates)
return updateToLeaveMembership(ctx, mu, add, newMembership, updates)
default:
panic(fmt.Errorf(
"input: membership %q is not one of the allowed values", newMembership,
@ -132,7 +133,7 @@ func updateMembership(
}
func updateToInviteMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
ctx context.Context, mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// We may have already sent the invite to the user, either because we are
// reprocessing this event, or because the we received this invite from a
@ -151,16 +152,24 @@ func updateToInviteMembership(
onie := api.OutputNewInviteEvent{
Event: *add,
}
updates = append(updates, api.OutputEvent{
oe := api.OutputEvent{
Type: api.OutputTypeNewInviteEvent,
NewInviteEvent: &onie,
})
}
err = oe.AddSpanFromContext(ctx)
if err != nil {
return nil, err
}
updates = append(updates, oe)
}
return updates, nil
}
func updateToJoinMembership(
mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
ctx context.Context, mu types.MembershipUpdater, add *gomatrixserverlib.Event, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
// If the user is already marked as being joined, we call SetToJoin to update
// the event ID then we can return immediately. Retired is ignored as there
@ -187,15 +196,24 @@ func updateToJoinMembership(
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
oe := api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
err = oe.AddSpanFromContext(ctx)
if err != nil {
return nil, err
}
updates = append(updates, oe)
}
return updates, nil
}
func updateToLeaveMembership(
ctx context.Context,
mu types.MembershipUpdater, add *gomatrixserverlib.Event,
newMembership string, updates []api.OutputEvent,
) ([]api.OutputEvent, error) {
@ -219,10 +237,18 @@ func updateToLeaveMembership(
RetiredByEventID: add.EventID(),
TargetUserID: *add.StateKey(),
}
updates = append(updates, api.OutputEvent{
oe := api.OutputEvent{
Type: api.OutputTypeRetireInviteEvent,
RetireInviteEvent: &orie,
})
}
err = oe.AddSpanFromContext(ctx)
if err != nil {
return nil, err
}
updates = append(updates, oe)
}
return updates, nil
}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by
@ -581,10 +582,10 @@ func getAuthChain(
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
// nolint: gocyclo
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(
api.RoomserverQueryLatestEventsAndStatePath,
common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -598,7 +599,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryStateAfterEventsPath,
common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -612,7 +613,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryEventsByIDPath,
common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -626,7 +627,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryMembershipForUserPath,
common.MakeInternalAPI("QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "QueryMembershipForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipForUserRequest
var response api.QueryMembershipForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -640,7 +641,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryMembershipsForRoomPath,
common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -654,7 +655,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryInvitesForUserPath,
common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryInvitesForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryInvitesForUserRequest
var response api.QueryInvitesForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -668,7 +669,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryServerAllowedToSeeEventPath,
common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -682,7 +683,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryMissingEventsPath,
common.MakeInternalAPI("queryMissingEvents", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryMissingEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryMissingEventsRequest
var response api.QueryMissingEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -696,7 +697,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryStateAndAuthChainPath,
common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -709,3 +710,127 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
}),
)
}
type InProcessRoomServerQueryAPI struct {
db RoomserverQueryAPI
tracer opentracing.Tracer
}
func NewInProcessRoomServerQueryAPI(db RoomserverQueryAPI, tracer opentracing.Tracer) InProcessRoomServerQueryAPI {
return InProcessRoomServerQueryAPI{
db: db,
tracer: tracer,
}
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryLatestEventsAndState(
ctx context.Context,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
span := h.tracer.StartSpan(
"QueryLatestEventsAndState",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryLatestEventsAndState(ctx, request, response)
}
// QueryStateAfterEvents implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryStateAfterEvents(
ctx context.Context,
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
span := h.tracer.StartSpan(
"QueryStateAfterEvents",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryStateAfterEvents(ctx, request, response)
}
// QueryEventsByID implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
span := h.tracer.StartSpan(
"QueryEventsByID",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryEventsByID(ctx, request, response)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryMembershipsForRoom(
ctx context.Context,
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
span := h.tracer.StartSpan(
"QueryMembershipsForRoom",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryMembershipsForRoom(ctx, request, response)
}
// QueryInvitesForUser implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryInvitesForUser(
ctx context.Context,
request *api.QueryInvitesForUserRequest,
response *api.QueryInvitesForUserResponse,
) error {
span := h.tracer.StartSpan(
"QueryInvitesForUser",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryInvitesForUser(ctx, request, response)
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
span := h.tracer.StartSpan(
"QueryServerAllowedToSeeEvent",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryServerAllowedToSeeEvent(ctx, request, response)
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryStateAndAuthChain(
ctx context.Context,
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
span := h.tracer.StartSpan(
"QueryStateAndAuthChain",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryStateAndAuthChain(ctx, request, response)
}

View file

@ -17,10 +17,10 @@ package roomserver
import (
"net/http"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/roomserver/alias"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/input"
"github.com/matrix-org/dendrite/roomserver/query"
"github.com/matrix-org/dendrite/roomserver/storage"
@ -33,8 +33,11 @@ import (
// APIs directly instead of having to use HTTP.
func SetupRoomServerComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
) (api.RoomserverAliasAPI, api.RoomserverInputAPI, api.RoomserverQueryAPI) {
roomserverDB, err := storage.Open(string(base.Cfg.Database.RoomServer))
tracer := tracers.SetupNewTracer("Dendrite - RoomserverAPI")
roomserverDB, err := storage.Open(tracers, string(base.Cfg.Database.RoomServer))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to room server db")
}
@ -45,11 +48,11 @@ func SetupRoomServerComponent(
OutputRoomEventTopic: string(base.Cfg.Kafka.Topics.OutputRoomEvent),
}
inputAPI.SetupHTTP(http.DefaultServeMux)
inputAPI.SetupHTTP(http.DefaultServeMux, tracer)
queryAPI := query.RoomserverQueryAPI{DB: roomserverDB}
queryAPI.SetupHTTP(http.DefaultServeMux)
queryAPI.SetupHTTP(http.DefaultServeMux, tracer)
aliasAPI := alias.RoomserverAliasAPI{
DB: roomserverDB,
@ -58,7 +61,7 @@ func SetupRoomServerComponent(
QueryAPI: &queryAPI,
}
aliasAPI.SetupHTTP(http.DefaultServeMux)
aliasAPI.SetupHTTP(http.DefaultServeMux, tracer)
return &aliasAPI, &inputAPI, &queryAPI
}

View file

@ -20,6 +20,7 @@ import (
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
@ -32,10 +33,10 @@ type Database struct {
}
// Open a postgres database.
func Open(dataSourceName string) (*Database, error) {
func Open(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var d Database
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "roomserver", dataSourceName); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {

View file

@ -19,6 +19,8 @@ import (
"encoding/json"
"fmt"
"github.com/opentracing/opentracing-go"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
@ -36,6 +38,7 @@ type OutputRoomEventConsumer struct {
db *storage.SyncServerDatabase
notifier *sync.Notifier
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -45,6 +48,7 @@ func NewOutputRoomEventConsumer(
n *sync.Notifier,
store *storage.SyncServerDatabase,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
@ -57,6 +61,7 @@ func NewOutputRoomEventConsumer(
db: store,
notifier: n,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -80,13 +85,16 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
switch output.Type {
case api.OutputTypeNewRoomEvent:
return s.onNewRoomEvent(context.TODO(), *output.NewRoomEvent)
return s.onNewRoomEvent(ctx, *output.NewRoomEvent)
case api.OutputTypeNewInviteEvent:
return s.onNewInviteEvent(context.TODO(), *output.NewInviteEvent)
return s.onNewInviteEvent(ctx, *output.NewInviteEvent)
case api.OutputTypeRetireInviteEvent:
return s.onRetireInviteEvent(context.TODO(), *output.RetireInviteEvent)
return s.onRetireInviteEvent(ctx, *output.RetireInviteEvent)
default:
log.WithField("type", output.Type).Debug(
"roomserver output log: ignoring unknown output type",

View file

@ -25,32 +25,37 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixR0 = "/_matrix/client/r0"
// Setup configures the given mux with sync-server listeners
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) {
func Setup(
apiMux *mux.Router,
srp *sync.RequestPool,
syncDB *storage.SyncServerDatabase,
deviceDB *devices.Database,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
authData := auth.Data{nil, deviceDB, nil}
// TODO: Add AS support for all handlers below.
r0mux.Handle("/sync", common.MakeAuthAPI("sync", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/sync", common.MakeAuthAPI(tracer, "sync", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI(tracer, "room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateRequest(req, syncDB, vars["roomID"])
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI(tracer, "room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], "")
})).Methods(http.MethodGet, http.MethodOptions)
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI("room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI(tracer, "room_state", authData, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods(http.MethodGet, http.MethodOptions)

View file

@ -24,7 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/roomserver/api"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -57,10 +57,10 @@ type SyncServerDatabase struct {
}
// NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
func NewSyncServerDatabase(tracers *common.Tracers, dataSourceName string) (*SyncServerDatabase, error) {
var d SyncServerDatabase
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "sync", dataSourceName); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {

View file

@ -20,10 +20,10 @@ import (
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/basecomponent"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/syncapi/consumers"
"github.com/matrix-org/dendrite/syncapi/routing"
"github.com/matrix-org/dendrite/syncapi/storage"
@ -35,11 +35,14 @@ import (
// component.
func SetupSyncAPIComponent(
base *basecomponent.BaseDendrite,
tracers *common.Tracers,
deviceDB *devices.Database,
accountsDB *accounts.Database,
queryAPI api.RoomserverQueryAPI,
) {
syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI))
tracer := tracers.SetupNewTracer("Dendrite - RoomserverAPI")
syncDB, err := storage.NewSyncServerDatabase(tracers, string(base.Cfg.Database.SyncAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to sync db")
}
@ -58,7 +61,7 @@ func SetupSyncAPIComponent(
requestPool := sync.NewRequestPool(syncDB, notifier, accountsDB)
roomConsumer := consumers.NewOutputRoomEventConsumer(
base.Cfg, base.KafkaConsumer, notifier, syncDB, queryAPI,
base.Cfg, base.KafkaConsumer, notifier, syncDB, queryAPI, tracer,
)
if err = roomConsumer.Start(); err != nil {
logrus.WithError(err).Panicf("failed to start room server consumer")
@ -71,5 +74,5 @@ func SetupSyncAPIComponent(
logrus.WithError(err).Panicf("failed to start client data consumer")
}
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB)
routing.Setup(base.APIMux, requestPool, syncDB, deviceDB, tracer)
}

6
vendor/manifest vendored
View file

@ -77,6 +77,12 @@
"revision": "44cc805cf13205b55f69e14bcb69867d1ae92f98",
"branch": "master"
},
{
"importpath": "github.com/gchaincl/sqlhooks",
"repository": "https://github.com/gchaincl/sqlhooks",
"revision": "b4a12bad76664eae8012d196ed901f8fa8f87909",
"branch": "master"
},
{
"importpath": "github.com/golang/protobuf/proto",
"repository": "https://github.com/golang/protobuf",

View file

@ -0,0 +1,41 @@
# Change Log
## [Unreleased](https://github.com/gchaincl/sqlhooks/tree/HEAD)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v1.0.0...HEAD)
**Closed issues:**
- Add Benchmarks [\#9](https://github.com/gchaincl/sqlhooks/issues/9)
## [v1.0.0](https://github.com/gchaincl/sqlhooks/tree/v1.0.0) (2017-05-08)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.4...v1.0.0)
**Merged pull requests:**
- Godoc [\#7](https://github.com/gchaincl/sqlhooks/pull/7) ([gchaincl](https://github.com/gchaincl))
- Make covermode=count [\#6](https://github.com/gchaincl/sqlhooks/pull/6) ([gchaincl](https://github.com/gchaincl))
- V1 [\#5](https://github.com/gchaincl/sqlhooks/pull/5) ([gchaincl](https://github.com/gchaincl))
- Expose a WrapDriver function [\#4](https://github.com/gchaincl/sqlhooks/issues/4)
- Implement new 1.8 interfaces [\#3](https://github.com/gchaincl/sqlhooks/issues/3)
## [v0.4](https://github.com/gchaincl/sqlhooks/tree/v0.4) (2017-03-23)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.3...v0.4)
## [v0.3](https://github.com/gchaincl/sqlhooks/tree/v0.3) (2016-06-02)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.2...v0.3)
**Closed issues:**
- Change Notifications [\#2](https://github.com/gchaincl/sqlhooks/issues/2)
## [v0.2](https://github.com/gchaincl/sqlhooks/tree/v0.2) (2016-05-01)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.1...v0.2)
## [v0.1](https://github.com/gchaincl/sqlhooks/tree/v0.1) (2016-04-25)
**Merged pull requests:**
- Sqlite3 [\#1](https://github.com/gchaincl/sqlhooks/pull/1) ([gchaincl](https://github.com/gchaincl))
\* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)*

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Gustavo Chaín
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,82 @@
# sqlhooks [![Build Status](https://travis-ci.org/gchaincl/sqlhooks.svg)](https://travis-ci.org/gchaincl/sqlhooks) [![Coverage Status](https://coveralls.io/repos/github/gchaincl/sqlhooks/badge.svg?branch=master)](https://coveralls.io/github/gchaincl/sqlhooks?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/gchaincl/sqlhooks)](https://goreportcard.com/report/github.com/gchaincl/sqlhooks)
Attach hooks to any database/sql driver.
The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
# Install
```bash
go get github.com/gchaincl/sqlhooks
```
## Breaking changes
`V1` isn't backward compatible with previous versions, if you want to fetch old versions, you can get them from [gopkg.in](http://gopkg.in/)
```bash
go get gopkg.in/gchaincl/sqlhooks.v0
```
# Usage [![GoDoc](https://godoc.org/github.com/gchaincl/dotsql?status.svg)](https://godoc.org/github.com/gchaincl/sqlhooks)
```go
// This example shows how to instrument sql queries in order to display the time that they consume
package main
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/gchaincl/sqlhooks"
"github.com/mattn/go-sqlite3"
)
// Hooks satisfies the sqlhook.Hooks interface
type Hooks struct {}
// Before hook will print the query with it's args and return the context with the timestamp
func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
fmt.Printf("> %s %q", query, args)
return context.WithValue(ctx, "begin", time.Now()), nil
}
// After hook will get the timestamp registered on the Before hook and print the elapsed time
func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
begin := ctx.Value("begin").(time.Time)
fmt.Printf(". took: %s\n", time.Since(begin))
return ctx, nil
}
func main() {
// First, register the wrapper
sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
// Connect to the registered wrapped driver
db, _ := sql.Open("sqlite3WithHooks", ":memory:")
// Do you're stuff
db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
db.Query("SELECT id, text FROM t")
}
/*
Output should look like:
> CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
> INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
> SELECT id, text FROM t []. took: 4.653µs
*/
```
# Benchmarks
```
go test -bench=. -benchmem
BenchmarkSQLite3/Without_Hooks-4 200000 8572 ns/op 627 B/op 16 allocs/op
BenchmarkSQLite3/With_Hooks-4 200000 10231 ns/op 738 B/op 18 allocs/op
BenchmarkMySQL/Without_Hooks-4 10000 108421 ns/op 437 B/op 10 allocs/op
BenchmarkMySQL/With_Hooks-4 10000 226085 ns/op 597 B/op 13 allocs/op
BenchmarkPostgres/Without_Hooks-4 10000 125718 ns/op 649 B/op 17 allocs/op
BenchmarkPostgres/With_Hooks-4 5000 354831 ns/op 1122 B/op 27 allocs/op
PASS
ok github.com/gchaincl/sqlhooks 11.713s
```

View file

@ -0,0 +1,76 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func init() {
hooks := &testHooks{}
hooks.noop()
sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks))
sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks))
sql.Register("postgres-benchmark", Wrap(&pq.Driver{}, hooks))
}
func benchmark(b *testing.B, driver, dsn string) {
db, err := sql.Open(driver, dsn)
require.NoError(b, err)
defer db.Close()
var query = "SELECT 'hello'"
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query(query)
require.NoError(b, err)
require.NoError(b, rows.Close())
}
}
func BenchmarkSQLite3(b *testing.B) {
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "sqlite3", ":memory:")
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "sqlite3-benchmark", ":memory:")
})
}
func BenchmarkMySQL(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "mysql", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "mysql-benchmark", dsn)
})
}
func BenchmarkPostgres(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "postgres", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "postgres-benchmark", dsn)
})
}

View file

@ -0,0 +1,52 @@
// package sqlhooks allows you to attach hooks to any database/sql driver.
// The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
// This example shows how to instrument sql queries in order to display the time that they consume
// package main
//
// import (
// "context"
// "database/sql"
// "fmt"
// "time"
//
// "github.com/gchaincl/sqlhooks"
// "github.com/mattn/go-sqlite3"
// )
//
// // Hooks satisfies the sqlhook.Hooks interface
// type Hooks struct {}
//
// // Before hook will print the query with it's args and return the context with the timestamp
// func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// fmt.Printf("> %s %q", query, args)
// return context.WithValue(ctx, "begin", time.Now()), nil
// }
//
// // After hook will get the timestamp registered on the Before hook and print the elapsed time
// func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// begin := ctx.Value("begin").(time.Time)
// fmt.Printf(". took: %s\n", time.Since(begin))
// return ctx, nil
// }
//
// func main() {
// // First, register the wrapper
// sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
//
// // Connect to the registered wrapped driver
// db, _ := sql.Open("sqlite3WithHooks", ":memory:")
//
// // Do you're stuff
// db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
// db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
// db.Query("SELECT id, text FROM t")
// }
//
// /*
// Output should look like:
// > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
// > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
// > SELECT id, text FROM t []. took: 4.653µs
// */
package sqlhooks

View file

@ -0,0 +1,17 @@
package loghooks
import (
"database/sql"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
)
func Example() {
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New())
sql.Register("sqlite3-logger", driver)
db, _ := sql.Open("sqlite3-logger", ":memory:")
// This query will output logs
db.Query("SELECT 1+1")
}

View file

@ -0,0 +1,31 @@
package main
import (
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/loghooks"
"github.com/mattn/go-sqlite3"
)
func main() {
sql.Register("sqlite3log", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, loghooks.New()))
db, err := sql.Open("sqlite3log", ":memory:")
if err != nil {
log.Fatal(err)
}
if _, err := db.Exec("CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.Exec(`INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.Query(`SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,30 @@
package loghooks
import (
"context"
"log"
"os"
"time"
)
type logger interface {
Printf(string, ...interface{})
}
type Hook struct {
log logger
}
func New() *Hook {
return &Hook{
log: log.New(os.Stderr, "", log.LstdFlags),
}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return context.WithValue(ctx, "started", time.Now()), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value("started").(time.Time)))
return ctx, nil
}

View file

@ -0,0 +1,39 @@
package main
import (
"context"
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/mattn/go-sqlite3"
"github.com/opentracing/opentracing-go"
)
func main() {
tracer := opentracing.GlobalTracer()
hooks := othooks.New(tracer)
sql.Register("sqlite3ot", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks))
db, err := sql.Open("sqlite3ot", ":memory:")
if err != nil {
log.Fatal(err)
}
span := tracer.StartSpan("sql")
defer span.Finish()
ctx := opentracing.ContextWithSpan(context.Background(), span)
if _, err := db.ExecContext(ctx, "CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.QueryContext(ctx, `SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,35 @@
package othooks
import "context"
import "github.com/opentracing/opentracing-go"
import "github.com/opentracing/opentracing-go/ext"
type Hook struct {
tracer opentracing.Tracer
}
func New(tracer opentracing.Tracer) *Hook {
return &Hook{tracer: tracer}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
parent := opentracing.SpanFromContext(ctx)
if parent == nil {
return ctx, nil
}
span := h.tracer.StartSpan("sql", opentracing.ChildOf(parent.Context()))
ext.DBStatement.Set(span, query)
return opentracing.ContextWithSpan(ctx, span), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
span := opentracing.SpanFromContext(ctx)
if span != nil {
defer span.Finish()
}
return ctx, nil
}

View file

@ -0,0 +1,74 @@
package othooks
import (
"context"
"database/sql"
"testing"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
tracer *mocktracer.MockTracer
)
func init() {
tracer = mocktracer.New()
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New(tracer))
sql.Register("ot", driver)
}
func TestSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
parent := tracer.StartSpan("parent")
ctx := opentracing.ContextWithSpan(context.Background(), parent)
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
parent.Finish()
spans := tracer.FinishedSpans()
require.Len(t, spans, 3)
span := spans[1]
assert.Equal(t, "sql", span.OperationName)
logFields := span.Logs()[0].Fields
assert.Equal(t, "query", logFields[0].Key)
assert.Equal(t, "SELECT 1+?", logFields[0].ValueString)
assert.Equal(t, "args", logFields[1].Key)
assert.Equal(t, "[1]", logFields[1].ValueString)
assert.NotEmpty(t, span.FinishTime)
}
func TesNoSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
rows, err := db.QueryContext(context.Background(), "SELECT 1")
require.NoError(t, err)
rows.Close()
assert.Empty(t, tracer.FinishedSpans())
}

View file

@ -0,0 +1,277 @@
package sqlhooks
import (
"context"
"errors"
"database/sql/driver"
)
// Hook is the hook callback signature
type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error)
// Hooks instances may be passed to Wrap() to define an instrumented driver
type Hooks interface {
Before(ctx context.Context, query string, args ...interface{}) (context.Context, error)
After(ctx context.Context, query string, args ...interface{}) (context.Context, error)
}
// Driver implements a database/sql/driver.Driver
type Driver struct {
driver.Driver
hooks Hooks
}
// Open opens a connection
func (drv *Driver) Open(name string) (driver.Conn, error) {
conn, err := drv.Driver.Open(name)
if err != nil {
return conn, err
}
wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
}
return wrapped, nil
}
// Conn implements a database/sql.driver.Conn
type Conn struct {
Conn driver.Conn
hooks Hooks
}
func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
var (
stmt driver.Stmt
err error
)
if c, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmt, err = c.PrepareContext(ctx, query)
} else {
stmt, err = conn.Prepare(query)
}
if err != nil {
return stmt, err
}
return &Stmt{stmt, conn.hooks, query}, nil
}
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
}
// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
*Conn
}
func isExecer(conn driver.Conn) bool {
switch conn.(type) {
case driver.ExecerContext:
return true
case driver.Execer:
return true
default:
return false
}
}
func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
switch c := conn.Conn.Conn.(type) {
case driver.ExecerContext:
return c.ExecContext(ctx, query, args)
case driver.Execer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Exec(query, dargs)
default:
// This should not happen
return nil, errors.New("ExecerContext created for a non Execer driver.Conn")
}
}
func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}
results, err := conn.execContext(ctx, query, args)
if err != nil {
return results, err
}
if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}
return results, err
}
func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) {
// We have to implement Exec since it is required in the current version of
// Go for it to run ExecContext. From Go 10 it will be optional. However,
// this code should never run since database/sql always prefers to run
// ExecContext.
return nil, errors.New("Exec was called when ExecContext was implemented")
}
// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
hooks Hooks
query string
}
func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s, ok := stmt.Stmt.(driver.StmtExecContext); ok {
return s.ExecContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Exec(values)
}
func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
results, err := stmt.execContext(ctx, args)
if err != nil {
return results, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return results, err
}
func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
return s.QueryContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Query(values)
}
func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
var err error
list := namedToInterface(args)
// Exec Before Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
rows, err := stmt.queryContext(ctx, args)
if err != nil {
return rows, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return rows, err
}
func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }
func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() }
func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) }
func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) }
// Wrap is used to create a new instrumented driver, it takes a vendor specific driver, and a Hooks instance to produce a new driver instance.
// It's usually used inside a sql.Register() statement
func Wrap(driver driver.Driver, hooks Hooks) driver.Driver {
return &Driver{driver, hooks}
}
func namedToInterface(args []driver.NamedValue) []interface{} {
list := make([]interface{}, len(args))
for i, a := range args {
list[i] = a.Value
}
return list
}
// namedValueToValue copied from database/sql
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
/*
type hooks struct {
}
func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error {
log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error {
log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func main() {
sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{}))
db, err := sql.Open("sqlite3-proxy", ":memory:")
if err != nil {
log.Fatalln(err)
}
if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok {
panic("NOPE")
}
if _, err := db.Exec("CREATE table users(id int)"); err != nil {
log.Printf("|err| = %+v\n", err)
}
if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil {
log.Printf("err = %+v\n", err)
}
}
*/

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpMySQL(t *testing.T, dsn string) {
db, err := sql.Open("mysql", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestMySQL(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
setUpMySQL(t, dsn)
s := newSuite(t, &mysql.MySQLDriver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpPostgres(t *testing.T, dsn string) {
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestPostgres(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
setUpPostgres(t, dsn)
s := newSuite(t, &pq.Driver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,54 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUp(t *testing.T) func() {
dbName := "sqlite3test.db"
db, err := sql.Open("sqlite3", dbName)
require.NoError(t, err)
defer db.Close()
_, err = db.Exec("CREATE table users(id int, name text)")
require.NoError(t, err)
return func() { os.Remove(dbName) }
}
func TestSQLite3(t *testing.T) {
defer setUp(t)()
s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db")
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for range [5]struct{}{} {
_, err := stmt.Exec(time.Now().UnixNano(), "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,167 @@
package sqlhooks
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testHooks struct {
before Hook
after Hook
}
func (h *testHooks) noop() {
noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}
h.before, h.after = noop, noop
}
func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.before(ctx, query, args...)
}
func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.after(ctx, query, args...)
}
type suite struct {
db *sql.DB
hooks *testHooks
}
func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
hooks := &testHooks{}
driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String())
sql.Register(driverName, Wrap(driver, hooks))
db, err := sql.Open(driverName, dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
return &suite{db, hooks}
}
func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
var before, after bool
s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
before = true
return ctx, nil
}
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
after = true
return ctx, nil
}
t.Run("Query", func(t *testing.T) {
before, after = false, false
_, err := s.db.Query(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("QueryContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.QueryContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Exec", func(t *testing.T) {
before, after = false, false
_, err := s.db.Exec(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("ExecContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.ExecContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Statements", func(t *testing.T) {
before, after = false, false
stmt, err := s.db.Prepare(query)
require.NoError(t, err)
// Hooks just run when the stmt is executed (Query or Exec)
assert.False(t, before, "Before Hook run before execution: "+query)
assert.False(t, after, "After Hook run before execution: "+query)
stmt.Query(args...)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
}
func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) {
hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
assert.Equal(t, query, q)
assert.Equal(t, args, a)
assert.Equal(t, "val", ctx.Value("key").(string))
return ctx, nil
}
s.hooks.before = hook
s.hooks.after = hook
ctx := context.WithValue(context.Background(), "key", "val")
{
_, err := s.db.QueryContext(ctx, query, args...)
require.NoError(t, err)
}
{
_, err := s.db.ExecContext(ctx, query, args...)
require.NoError(t, err)
}
}
func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) {
t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) })
}
func (s *suite) testHooksErrors(t *testing.T, query string) {
boom := errors.New("boom")
s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, boom
}
s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
assert.False(t, true, "this should not run")
return ctx, nil
}
_, err := s.db.Query(query)
assert.Equal(t, boom, err)
}
func (s *suite) TestHooksErrors(t *testing.T, query string) {
t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) })
}
func TestNamedValueToValue(t *testing.T) {
named := []driver.NamedValue{
{Ordinal: 1, Value: "foo"},
{Ordinal: 2, Value: 42},
}
want := []driver.Value{"foo", 42}
dargs, err := namedValueToValue(named)
require.NoError(t, err)
assert.Equal(t, want, dargs)
}