All the opentracing stuff

This commit is contained in:
Erik Johnston 2017-12-04 09:57:03 +00:00
parent a5afbf4404
commit 0fc6432856
51 changed files with 1668 additions and 208 deletions

View file

@ -24,7 +24,6 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
// Import the postgres database driver.
_ "github.com/lib/pq"
)
// Database represents an account database
@ -41,10 +40,10 @@ type Database struct {
}
// NewDatabase creates a new accounts and profiles database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "accounts", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}

View file

@ -31,10 +31,10 @@ type Database struct {
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "devices", dataSourceName); err != nil {
return nil, err
}
d := devicesStatements{}

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
@ -34,6 +35,7 @@ type OutputRoomEventConsumer struct {
db *accounts.Database
query api.RoomserverQueryAPI
serverName string
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -42,6 +44,7 @@ func NewOutputRoomEventConsumer(
kafkaConsumer sarama.Consumer,
store *accounts.Database,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
@ -54,6 +57,7 @@ func NewOutputRoomEventConsumer(
db: store,
query: queryAPI,
serverName: string(cfg.Matrix.ServerName),
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -77,7 +81,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background())
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
if output.Type != api.OutputTypeNewRoomEvent {

View file

@ -30,6 +30,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixV1 = "/_matrix/client/api/v1"
@ -48,10 +49,10 @@ func Setup(
keyRing gomatrixserverlib.KeyRing,
userUpdateProducer *producers.UserUpdateProducer,
syncProducer *producers.SyncAPIProducer,
tracer opentracing.Tracer,
) {
apiMux.Handle("/_matrix/client/versions",
common.MakeExternalAPI("versions", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "versions", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{
Code: 200,
JSON: struct {
@ -70,12 +71,12 @@ func Setup(
unstableMux := apiMux.PathPrefix(pathPrefixUnstable).Subrouter()
r0mux.Handle("/createRoom",
common.MakeAuthAPI("createRoom", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "createRoom", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return CreateRoom(req, device, cfg, producer, accountDB, aliasAPI)
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/join/{roomIDOrAlias}",
common.MakeAuthAPI("join", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "join", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return JoinRoomByIDOrAlias(
req, device, vars["roomIDOrAlias"], cfg, federation, producer, queryAPI, aliasAPI, keyRing, accountDB,
@ -83,26 +84,26 @@ func Setup(
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/{membership:(?:join|kick|ban|unban|leave|invite)}",
common.MakeAuthAPI("membership", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "membership", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SendMembership(req, accountDB, device, vars["roomID"], vars["membership"], cfg, queryAPI, producer)
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/send/{eventType}",
common.MakeAuthAPI("send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, nil, cfg, queryAPI, producer)
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/send/{eventType}/{txnID}",
common.MakeAuthAPI("send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
txnID := vars["txnID"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], &txnID, nil, cfg, queryAPI, producer)
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{eventType:[^/]+/?}",
common.MakeAuthAPI("send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
emptyString := ""
eventType := vars["eventType"]
@ -114,54 +115,54 @@ func Setup(
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{eventType}/{stateKey}",
common.MakeAuthAPI("send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "send_message", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
stateKey := vars["stateKey"]
return SendEvent(req, device, vars["roomID"], vars["eventType"], nil, &stateKey, cfg, queryAPI, producer)
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
r0mux.Handle("/register", common.MakeExternalAPI(tracer, "register", func(req *http.Request) util.JSONResponse {
return Register(req, accountDB, deviceDB, &cfg)
})).Methods("POST", "OPTIONS")
v1mux.Handle("/register", common.MakeExternalAPI("register", func(req *http.Request) util.JSONResponse {
v1mux.Handle("/register", common.MakeExternalAPI(tracer, "register", func(req *http.Request) util.JSONResponse {
return LegacyRegister(req, accountDB, deviceDB, &cfg)
})).Methods("POST", "OPTIONS")
r0mux.Handle("/register/available", common.MakeExternalAPI("registerAvailable", func(req *http.Request) util.JSONResponse {
r0mux.Handle("/register/available", common.MakeExternalAPI(tracer, "registerAvailable", func(req *http.Request) util.JSONResponse {
return RegisterAvailable(req, accountDB)
})).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return DirectoryRoom(req, vars["roomAlias"], federation, &cfg, aliasAPI)
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetLocalAlias(req, device, vars["roomAlias"], &cfg, aliasAPI)
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/directory/room/{roomAlias}",
common.MakeAuthAPI("directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_room", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return RemoveLocalAlias(req, device, vars["roomAlias"], aliasAPI)
}),
).Methods("DELETE", "OPTIONS")
r0mux.Handle("/logout",
common.MakeAuthAPI("logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Logout(req, deviceDB, device)
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/logout/all",
common.MakeAuthAPI("logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "logout", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return LogoutAll(req, deviceDB, device)
}),
).Methods("POST", "OPTIONS")
@ -169,13 +170,13 @@ func Setup(
// Stub endpoints required by Riot
r0mux.Handle("/login",
common.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "login", func(req *http.Request) util.JSONResponse {
return Login(req, accountDB, deviceDB, cfg)
}),
).Methods("GET", "POST", "OPTIONS")
r0mux.Handle("/pushrules/",
common.MakeExternalAPI("push_rules", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "push_rules", func(req *http.Request) util.JSONResponse {
// TODO: Implement push rules API
res := json.RawMessage(`{
"global": {
@ -194,14 +195,14 @@ func Setup(
).Methods("GET", "OPTIONS")
r0mux.Handle("/user/{userId}/filter",
common.MakeAuthAPI("put_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "put_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return PutFilter(req, device, accountDB, vars["userId"])
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/user/{userId}/filter/{filterId}",
common.MakeAuthAPI("get_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_filter", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetFilter(req, device, accountDB, vars["userId"], vars["filterId"])
}),
@ -210,21 +211,21 @@ func Setup(
// Riot user settings
r0mux.Handle("/profile/{userID}",
common.MakeExternalAPI("profile", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetProfile(req, accountDB, vars["userID"])
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeExternalAPI("profile_avatar_url", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile_avatar_url", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetAvatarURL(req, accountDB, vars["userID"])
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/avatar_url",
common.MakeAuthAPI("profile_avatar_url", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "profile_avatar_url", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetAvatarURL(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
}),
@ -233,14 +234,14 @@ func Setup(
// PUT requests, so we need to allow this method
r0mux.Handle("/profile/{userID}/displayname",
common.MakeExternalAPI("profile_displayname", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "profile_displayname", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return GetDisplayName(req, accountDB, vars["userID"])
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/profile/{userID}/displayname",
common.MakeAuthAPI("profile_displayname", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "profile_displayname", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SetDisplayName(req, accountDB, device, vars["userID"], userUpdateProducer, &cfg, producer, queryAPI)
}),
@ -249,32 +250,32 @@ func Setup(
// PUT requests, so we need to allow this method
r0mux.Handle("/account/3pid",
common.MakeAuthAPI("account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetAssociated3PIDs(req, accountDB, device)
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/account/3pid",
common.MakeAuthAPI("account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return CheckAndSave3PIDAssociation(req, accountDB, device, cfg)
}),
).Methods("POST", "OPTIONS")
unstableMux.Handle("/account/3pid/delete",
common.MakeAuthAPI("account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "account_3pid", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return Forget3PID(req, accountDB)
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/{path:(?:account/3pid|register)}/email/requestToken",
common.MakeExternalAPI("account_3pid_request_token", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "account_3pid_request_token", func(req *http.Request) util.JSONResponse {
return RequestEmailToken(req, accountDB, cfg)
}),
).Methods("POST", "OPTIONS")
// Riot logs get flooded unless this is handled
r0mux.Handle("/presence/{userID}/status",
common.MakeExternalAPI("presence", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "presence", func(req *http.Request) util.JSONResponse {
// TODO: Set presence (probably the responsibility of a presence server not clientapi)
return util.JSONResponse{
Code: 200,
@ -284,13 +285,13 @@ func Setup(
).Methods("PUT", "OPTIONS")
r0mux.Handle("/voip/turnServer",
common.MakeAuthAPI("turn_server", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "turn_server", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return RequestTurnServer(req, device, cfg)
}),
).Methods("GET", "OPTIONS")
unstableMux.Handle("/thirdparty/protocols",
common.MakeExternalAPI("thirdparty_protocols", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "thirdparty_protocols", func(req *http.Request) util.JSONResponse {
// TODO: Return the third party protcols
return util.JSONResponse{
Code: 200,
@ -300,7 +301,7 @@ func Setup(
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/initialSync",
common.MakeExternalAPI("rooms_initial_sync", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_initial_sync", func(req *http.Request) util.JSONResponse {
// TODO: Allow people to peek into rooms.
return util.JSONResponse{
Code: 403,
@ -310,62 +311,62 @@ func Setup(
).Methods("GET", "OPTIONS")
r0mux.Handle("/user/{userID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SaveAccountData(req, accountDB, device, vars["userID"], "", vars["type"], syncProducer)
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/user/{userID}/rooms/{roomID}/account_data/{type}",
common.MakeAuthAPI("user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "user_account_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return SaveAccountData(req, accountDB, device, vars["userID"], vars["roomID"], vars["type"], syncProducer)
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/members",
common.MakeAuthAPI("rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], false, cfg, queryAPI)
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/joined_members",
common.MakeAuthAPI("rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "rooms_members", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetMemberships(req, device, vars["roomID"], true, cfg, queryAPI)
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/read_markers",
common.MakeExternalAPI("rooms_read_markers", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_read_markers", func(req *http.Request) util.JSONResponse {
// TODO: return the read_markers.
return util.JSONResponse{Code: 200, JSON: struct{}{}}
}),
).Methods("POST", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/typing/{userID}",
common.MakeExternalAPI("rooms_typing", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "rooms_typing", func(req *http.Request) util.JSONResponse {
// TODO: handling typing
return util.JSONResponse{Code: 200, JSON: struct{}{}}
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/devices",
common.MakeAuthAPI("get_devices", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_devices", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return GetDevicesByLocalpart(req, deviceDB, device)
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/device/{deviceID}",
common.MakeAuthAPI("get_device", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "get_device", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return GetDeviceByID(req, deviceDB, device, vars["deviceID"])
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/devices/{deviceID}",
common.MakeAuthAPI("device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "device_data", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return UpdateDeviceByID(req, deviceDB, device, vars["deviceID"])
}),
@ -373,7 +374,7 @@ func Setup(
// Stub implementations for sytest
r0mux.Handle("/events",
common.MakeExternalAPI("events", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "events", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: 200, JSON: map[string]interface{}{
"chunk": []interface{}{},
"start": "",
@ -383,7 +384,7 @@ func Setup(
).Methods("GET", "OPTIONS")
r0mux.Handle("/initialSync",
common.MakeExternalAPI("initial_sync", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "initial_sync", func(req *http.Request) util.JSONResponse {
return util.JSONResponse{Code: 200, JSON: map[string]interface{}{
"end": "",
}}

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -63,7 +64,7 @@ func main() {
serverName := gomatrixserverlib.ServerName(*serverNameStr)
accountDB, err := accounts.NewDatabase(*database, serverName)
accountDB, err := accounts.NewDatabase(common.NoopTracers(), *database, serverName)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
@ -75,7 +76,7 @@ func main() {
os.Exit(1)
}
deviceDB, err := devices.NewDatabase(*database, serverName)
deviceDB, err := devices.NewDatabase(common.NoopTracers(), *database, serverName)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)

View file

@ -19,6 +19,8 @@ import (
"net/http"
"os"
"github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
@ -51,11 +53,13 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteClientAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - ClientAPI")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
queryAPI := api.NewRoomserverQueryAPIHTTP(cfg.RoomServerURL(), nil)
aliasAPI := api.NewRoomserverAliasAPIHTTP(cfg.RoomServerURL(), nil)
@ -85,15 +89,15 @@ func main() {
cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
)
accountDB, err := accounts.NewDatabase(string(cfg.Database.Account), cfg.Matrix.ServerName)
accountDB, err := accounts.NewDatabase(tracers, string(cfg.Database.Account), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("Failed to setup account database(%q): %s", cfg.Database.Account, err.Error())
}
deviceDB, err := devices.NewDatabase(string(cfg.Database.Device), cfg.Matrix.ServerName)
deviceDB, err := devices.NewDatabase(tracers, string(cfg.Database.Device), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("Failed to setup device database(%q): %s", cfg.Database.Device, err.Error())
}
keyDB, err := keydb.NewDatabase(string(cfg.Database.ServerKey))
keyDB, err := keydb.NewDatabase(tracers, string(cfg.Database.ServerKey))
if err != nil {
log.Panicf("Failed to setup key database(%q): %s", cfg.Database.ServerKey, err.Error())
}
@ -108,7 +112,7 @@ func main() {
}).Panic("Failed to setup kafka consumers")
}
consumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, accountDB, queryAPI)
consumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, accountDB, queryAPI, opentracing.GlobalTracer())
if err = consumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer")
}
@ -120,6 +124,7 @@ func main() {
api, *cfg, roomserverProducer,
queryAPI, aliasAPI, accountDB, deviceDB, federation, keyRing,
userUpdateProducer, syncProducer,
opentracing.GlobalTracer(),
)
common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(api))

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/dendrite/federationapi/routing"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
)
@ -50,22 +51,24 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteFederationAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - FenderationAPI")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
federation := gomatrixserverlib.NewFederationClient(
cfg.Matrix.ServerName, cfg.Matrix.KeyID, cfg.Matrix.PrivateKey,
)
keyDB, err := keydb.NewDatabase(string(cfg.Database.ServerKey))
keyDB, err := keydb.NewDatabase(tracers, string(cfg.Database.ServerKey))
if err != nil {
log.Panicf("Failed to setup key database(%q): %s", cfg.Database.ServerKey, err.Error())
}
accountDB, err := accounts.NewDatabase(string(cfg.Database.Account), cfg.Matrix.ServerName)
accountDB, err := accounts.NewDatabase(tracers, string(cfg.Database.Account), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("Failed to setup account database(%q): %s", cfg.Database.Account, err.Error())
}
@ -91,7 +94,7 @@ func main() {
log.Info("Starting federation API server on ", cfg.Listen.FederationAPI)
api := mux.NewRouter()
routing.Setup(api, *cfg, queryAPI, aliasAPI, roomserverProducer, keyRing, federation, accountDB)
routing.Setup(api, *cfg, queryAPI, aliasAPI, roomserverProducer, keyRing, federation, accountDB, opentracing.GlobalTracer())
common.SetupHTTPAPI(http.DefaultServeMux, api)
log.Fatal(http.ListenAndServe(string(cfg.Listen.FederationAPI), nil))

View file

@ -27,6 +27,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/storage"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
@ -47,11 +48,13 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteFederationSender")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - Federation Sender")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
kafkaConsumer, err := sarama.NewConsumer(cfg.Kafka.Addresses, nil)
if err != nil {
@ -63,7 +66,7 @@ func main() {
queryAPI := api.NewRoomserverQueryAPIHTTP(cfg.RoomServerURL(), nil)
db, err := storage.NewDatabase(string(cfg.Database.FederationSender))
db, err := storage.NewDatabase(tracers, string(cfg.Database.FederationSender))
if err != nil {
log.Panicf("startup: failed to create federation sender database with data source %s : %s", cfg.Database.FederationSender, err)
}
@ -74,7 +77,7 @@ func main() {
queues := queue.NewOutgoingQueues(cfg.Matrix.ServerName, federation)
consumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, queues, db, queryAPI)
consumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, queues, db, queryAPI, opentracing.GlobalTracer())
if err = consumer.Start(); err != nil {
log.WithError(err).Panicf("startup: failed to start room server consumer")
}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/mediaapi/routing"
"github.com/matrix-org/dendrite/mediaapi/storage"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
)
@ -48,18 +49,20 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteMediaAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - MediaAPI")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
db, err := storage.Open(string(cfg.Database.MediaAPI))
db, err := storage.Open(tracers, string(cfg.Database.MediaAPI))
if err != nil {
log.WithError(err).Panic("Failed to open database")
}
deviceDB, err := devices.NewDatabase(string(cfg.Database.Device), cfg.Matrix.ServerName)
deviceDB, err := devices.NewDatabase(tracers, string(cfg.Database.Device), cfg.Matrix.ServerName)
if err != nil {
log.WithError(err).Panicf("Failed to setup device database(%q)", cfg.Database.Device)
}
@ -69,7 +72,7 @@ func main() {
log.Info("Starting media API server on ", cfg.Listen.MediaAPI)
api := mux.NewRouter()
routing.Setup(api, cfg, db, deviceDB, client)
routing.Setup(api, cfg, db, deviceDB, client, opentracing.GlobalTracer())
common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(api))
log.Fatal(http.ListenAndServe(string(cfg.Listen.MediaAPI), nil))

View file

@ -21,6 +21,8 @@ import (
"net/http"
"os"
"github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
@ -84,13 +86,15 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteMonolith")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("DendriteMonolith")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
m := newMonolith(cfg)
m := newMonolith(cfg, tracers)
m.setupDatabases()
m.setupFederation()
m.setupKafka()
@ -136,8 +140,8 @@ type monolith struct {
federation *gomatrixserverlib.FederationClient
keyRing gomatrixserverlib.KeyRing
inputAPI *roomserver_input.RoomserverInputAPI
queryAPI *roomserver_query.RoomserverQueryAPI
inputAPI *roomserver_input.InProcessRoomServerInput
queryAPI *roomserver_query.InProcessRoomServerQueryAPI
aliasAPI *roomserver_alias.RoomserverAliasAPI
naffka *naffka.Naffka
@ -148,43 +152,61 @@ type monolith struct {
syncProducer *producers.SyncAPIProducer
syncAPINotifier *syncapi_sync.Notifier
tracers *common.Tracers
clientAPITracer opentracing.Tracer
syncAPITracer opentracing.Tracer
mediaAPITracer opentracing.Tracer
federationAPITracer opentracing.Tracer
publicRoomsAPITracer opentracing.Tracer
roomServerTracer opentracing.Tracer
}
func newMonolith(cfg *config.Dendrite) *monolith {
return &monolith{cfg: cfg, api: mux.NewRouter()}
func newMonolith(cfg *config.Dendrite, tracers *common.Tracers) *monolith {
return &monolith{
cfg: cfg,
api: mux.NewRouter(),
tracers: tracers,
clientAPITracer: tracers.SetupNewTracer("ClientAPI"),
syncAPITracer: tracers.SetupNewTracer("SyncAPI"),
mediaAPITracer: tracers.SetupNewTracer("MediaAPI"),
federationAPITracer: tracers.SetupNewTracer("FederationAPI"),
publicRoomsAPITracer: tracers.SetupNewTracer("PublicRooms"),
roomServerTracer: tracers.SetupNewTracer("RoomServer"),
}
}
func (m *monolith) setupDatabases() {
var err error
m.roomServerDB, err = roomserver_storage.Open(string(m.cfg.Database.RoomServer))
m.roomServerDB, err = roomserver_storage.Open(m.tracers, string(m.cfg.Database.RoomServer))
if err != nil {
panic(err)
}
m.accountDB, err = accounts.NewDatabase(string(m.cfg.Database.Account), m.cfg.Matrix.ServerName)
m.accountDB, err = accounts.NewDatabase(m.tracers, string(m.cfg.Database.Account), m.cfg.Matrix.ServerName)
if err != nil {
log.Panicf("Failed to setup account database(%q): %s", m.cfg.Database.Account, err.Error())
}
m.deviceDB, err = devices.NewDatabase(string(m.cfg.Database.Device), m.cfg.Matrix.ServerName)
m.deviceDB, err = devices.NewDatabase(m.tracers, string(m.cfg.Database.Device), m.cfg.Matrix.ServerName)
if err != nil {
log.Panicf("Failed to setup device database(%q): %s", m.cfg.Database.Device, err.Error())
}
m.keyDB, err = keydb.NewDatabase(string(m.cfg.Database.ServerKey))
m.keyDB, err = keydb.NewDatabase(m.tracers, string(m.cfg.Database.ServerKey))
if err != nil {
log.Panicf("Failed to setup key database(%q): %s", m.cfg.Database.ServerKey, err.Error())
}
m.mediaAPIDB, err = mediaapi_storage.Open(string(m.cfg.Database.MediaAPI))
m.mediaAPIDB, err = mediaapi_storage.Open(m.tracers, string(m.cfg.Database.MediaAPI))
if err != nil {
log.Panicf("Failed to setup sync api database(%q): %s", m.cfg.Database.MediaAPI, err.Error())
}
m.syncAPIDB, err = syncapi_storage.NewSyncServerDatabase(string(m.cfg.Database.SyncAPI))
m.syncAPIDB, err = syncapi_storage.NewSyncServerDatabase(m.tracers, string(m.cfg.Database.SyncAPI))
if err != nil {
log.Panicf("Failed to setup sync api database(%q): %s", m.cfg.Database.SyncAPI, err.Error())
}
m.federationSenderDB, err = federationsender_storage.NewDatabase(string(m.cfg.Database.FederationSender))
m.federationSenderDB, err = federationsender_storage.NewDatabase(m.tracers, string(m.cfg.Database.FederationSender))
if err != nil {
log.Panicf("startup: failed to create federation sender database with data source %s : %s", m.cfg.Database.FederationSender, err)
}
m.publicRoomsAPIDB, err = publicroomsapi_storage.NewPublicRoomsServerDatabase(string(m.cfg.Database.PublicRoomsAPI))
m.publicRoomsAPIDB, err = publicroomsapi_storage.NewPublicRoomsServerDatabase(m.tracers, string(m.cfg.Database.PublicRoomsAPI))
if err != nil {
log.Panicf("startup: failed to setup public rooms api database with data source %s : %s", m.cfg.Database.PublicRoomsAPI, err)
}
@ -249,15 +271,22 @@ func (m *monolith) kafkaConsumer() sarama.Consumer {
}
func (m *monolith) setupRoomServer() {
m.inputAPI = &roomserver_input.RoomserverInputAPI{
DB: m.roomServerDB,
Producer: m.kafkaProducer,
OutputRoomEventTopic: string(m.cfg.Kafka.Topics.OutputRoomEvent),
}
m.inputAPI = roomserver_input.NewInProcessRoomServerInput(
roomserver_input.RoomserverInputAPI{
DB: m.roomServerDB,
Producer: m.kafkaProducer,
OutputRoomEventTopic: string(m.cfg.Kafka.Topics.OutputRoomEvent),
},
m.roomServerTracer,
)
m.queryAPI = &roomserver_query.RoomserverQueryAPI{
DB: m.roomServerDB,
}
q := roomserver_query.NewInProcessRoomServerQueryAPI(
roomserver_query.RoomserverQueryAPI{
DB: m.roomServerDB,
},
m.roomServerTracer,
)
m.queryAPI = &q
m.aliasAPI = &roomserver_alias.RoomserverAliasAPI{
DB: m.roomServerDB,
@ -296,6 +325,7 @@ func (m *monolith) setupConsumers() {
clientAPIConsumer := clientapi_consumers.NewOutputRoomEventConsumer(
m.cfg, m.kafkaConsumer(), m.accountDB, m.queryAPI,
m.clientAPITracer,
)
if err = clientAPIConsumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer: %s", err)
@ -303,6 +333,7 @@ func (m *monolith) setupConsumers() {
syncAPIRoomConsumer := syncapi_consumers.NewOutputRoomEventConsumer(
m.cfg, m.kafkaConsumer(), m.syncAPINotifier, m.syncAPIDB, m.queryAPI,
m.syncAPITracer,
)
if err = syncAPIRoomConsumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer: %s", err)
@ -316,7 +347,7 @@ func (m *monolith) setupConsumers() {
}
publicRoomsAPIConsumer := publicroomsapi_consumers.NewOutputRoomEventConsumer(
m.cfg, m.kafkaConsumer(), m.publicRoomsAPIDB, m.queryAPI,
m.cfg, m.kafkaConsumer(), m.publicRoomsAPIDB, m.queryAPI, m.publicRoomsAPITracer,
)
if err = publicRoomsAPIConsumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer: %s", err)
@ -326,6 +357,7 @@ func (m *monolith) setupConsumers() {
federationSenderRoomConsumer := federationsender_consumers.NewOutputRoomEventConsumer(
m.cfg, m.kafkaConsumer(), federationSenderQueues, m.federationSenderDB, m.queryAPI,
m.federationAPITracer,
)
if err = federationSenderRoomConsumer.Start(); err != nil {
log.WithError(err).Panicf("startup: failed to start room server consumer")
@ -337,20 +369,22 @@ func (m *monolith) setupAPIs() {
m.api, *m.cfg, m.roomServerProducer,
m.queryAPI, m.aliasAPI, m.accountDB, m.deviceDB, m.federation, m.keyRing,
m.userUpdateProducer, m.syncProducer,
m.clientAPITracer,
)
mediaapi_routing.Setup(
m.api, m.cfg, m.mediaAPIDB, m.deviceDB, &m.federation.Client,
m.mediaAPITracer,
)
syncapi_routing.Setup(m.api, syncapi_sync.NewRequestPool(
m.syncAPIDB, m.syncAPINotifier, m.accountDB,
), m.syncAPIDB, m.deviceDB)
), m.syncAPIDB, m.deviceDB, m.syncAPITracer)
federationapi_routing.Setup(
m.api, *m.cfg, m.queryAPI, m.aliasAPI, m.roomServerProducer, m.keyRing, m.federation,
m.accountDB,
m.accountDB, m.federationAPITracer,
)
publicroomsapi_routing.Setup(m.api, m.deviceDB, m.publicRoomsAPIDB)
publicroomsapi_routing.Setup(m.api, m.deviceDB, m.publicRoomsAPIDB, m.publicRoomsAPITracer)
}

View file

@ -19,6 +19,8 @@ import (
"net/http"
"os"
"github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
"github.com/matrix-org/dendrite/common"
@ -47,20 +49,22 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendritePublicRoomsAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - Public Rooms API")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
queryAPI := api.NewRoomserverQueryAPIHTTP(cfg.RoomServerURL(), nil)
db, err := storage.NewPublicRoomsServerDatabase(string(cfg.Database.PublicRoomsAPI))
db, err := storage.NewPublicRoomsServerDatabase(tracers, string(cfg.Database.PublicRoomsAPI))
if err != nil {
log.Panicf("startup: failed to create public rooms server database with data source %s : %s", cfg.Database.PublicRoomsAPI, err)
}
deviceDB, err := devices.NewDatabase(string(cfg.Database.Device), cfg.Matrix.ServerName)
deviceDB, err := devices.NewDatabase(tracers, string(cfg.Database.Device), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.Database.Device, err)
}
@ -73,7 +77,7 @@ func main() {
}).Panic("Failed to setup kafka consumers")
}
roomConsumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, db, queryAPI)
roomConsumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, db, queryAPI, opentracing.GlobalTracer())
if err != nil {
log.Panicf("startup: failed to create room server consumer: %s", err)
}
@ -84,7 +88,7 @@ func main() {
log.Info("Starting public rooms server on ", cfg.Listen.PublicRoomsAPI)
api := mux.NewRouter()
routing.Setup(api, deviceDB, db)
routing.Setup(api, deviceDB, db, opentracing.GlobalTracer())
common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(api))
log.Fatal(http.ListenAndServe(string(cfg.Listen.PublicRoomsAPI), nil))

View file

@ -20,6 +20,8 @@ import (
_ "net/http/pprof"
"os"
"github.com/opentracing/opentracing-go"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/alias"
@ -49,13 +51,15 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteRoomServer")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - RoomServer")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
db, err := storage.Open(string(cfg.Database.RoomServer))
db, err := storage.Open(tracers, string(cfg.Database.RoomServer))
if err != nil {
panic(err)
}
@ -71,11 +75,11 @@ func main() {
OutputRoomEventTopic: string(cfg.Kafka.Topics.OutputRoomEvent),
}
inputAPI.SetupHTTP(http.DefaultServeMux)
inputAPI.SetupHTTP(http.DefaultServeMux, opentracing.GlobalTracer())
queryAPI := query.RoomserverQueryAPI{DB: db}
queryAPI.SetupHTTP(http.DefaultServeMux)
queryAPI.SetupHTTP(http.DefaultServeMux, opentracing.GlobalTracer())
aliasAPI := alias.RoomserverAliasAPI{
DB: db,
@ -84,7 +88,7 @@ func main() {
QueryAPI: &queryAPI,
}
aliasAPI.SetupHTTP(http.DefaultServeMux)
aliasAPI.SetupHTTP(http.DefaultServeMux, opentracing.GlobalTracer())
// This is deprecated, but prometheus are still arguing on what to replace
// it with. Alternatively we could set it up manually.

View file

@ -20,6 +20,8 @@ import (
"net/http"
"os"
"github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/accounts"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
@ -51,25 +53,27 @@ func main() {
log.Fatalf("Invalid config file: %s", err)
}
closer, err := cfg.SetupTracing("DendriteSyncAPI")
tracers := common.NewTracers(cfg)
defer tracers.Close() // nolint: errcheck
err = tracers.InitGlobalTracer("Dendrite - SyncAPI")
if err != nil {
log.WithError(err).Fatalf("Failed to start tracer")
}
defer closer.Close() // nolint: errcheck
queryAPI := api.NewRoomserverQueryAPIHTTP(cfg.RoomServerURL(), nil)
db, err := storage.NewSyncServerDatabase(string(cfg.Database.SyncAPI))
db, err := storage.NewSyncServerDatabase(tracers, string(cfg.Database.SyncAPI))
if err != nil {
log.Panicf("startup: failed to create sync server database with data source %s : %s", cfg.Database.SyncAPI, err)
}
deviceDB, err := devices.NewDatabase(string(cfg.Database.Device), cfg.Matrix.ServerName)
deviceDB, err := devices.NewDatabase(tracers, string(cfg.Database.Device), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("startup: failed to create device database with data source %s : %s", cfg.Database.Device, err)
}
adb, err := accounts.NewDatabase(string(cfg.Database.Account), cfg.Matrix.ServerName)
adb, err := accounts.NewDatabase(tracers, string(cfg.Database.Account), cfg.Matrix.ServerName)
if err != nil {
log.Panicf("startup: failed to create account database with data source %s : %s", cfg.Database.Account, err)
}
@ -92,7 +96,7 @@ func main() {
}).Panic("Failed to setup kafka consumers")
}
roomConsumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, n, db, queryAPI)
roomConsumer := consumers.NewOutputRoomEventConsumer(cfg, kafkaConsumer, n, db, queryAPI, opentracing.GlobalTracer())
if err = roomConsumer.Start(); err != nil {
log.Panicf("startup: failed to start room server consumer: %s", err)
}
@ -104,7 +108,7 @@ func main() {
log.Info("Starting sync server on ", cfg.Listen.SyncAPI)
api := mux.NewRouter()
routing.Setup(api, sync.NewRequestPool(db, n, adb), db, deviceDB)
routing.Setup(api, sync.NewRequestPool(db, n, adb), db, deviceDB, opentracing.GlobalTracer())
common.SetupHTTPAPI(http.DefaultServeMux, common.WrapHandlerInCORS(api))
log.Fatal(http.ListenAndServe(string(cfg.Listen.SyncAPI), nil))

View file

@ -19,7 +19,6 @@ import (
"crypto/sha256"
"encoding/pem"
"fmt"
"io"
"io/ioutil"
"path/filepath"
"strings"
@ -27,12 +26,10 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/gomatrixserverlib"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ed25519"
"gopkg.in/yaml.v2"
jaegerconfig "github.com/uber/jaeger-client-go/config"
jaegermetrics "github.com/uber/jaeger-lib/metrics"
)
// Version is the current version of the config format.
@ -514,25 +511,3 @@ func (config *Dendrite) RoomServerURL() string {
// internet for an internal API.
return "http://" + string(config.Listen.RoomServer)
}
// SetupTracing configures the opentracing using the supplied configuration.
func (config *Dendrite) SetupTracing(serviceName string) (closer io.Closer, err error) {
return config.Tracing.Jaeger.InitGlobalTracer(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
}
// logrusLogger is a small wrapper that implements jaeger.Logger using logrus.
type logrusLogger struct {
l *logrus.Logger
}
func (l logrusLogger) Error(msg string) {
l.l.Error(msg)
}
func (l logrusLogger) Infof(msg string, args ...interface{}) {
l.l.Infof(msg, args...)
}

View file

@ -14,7 +14,7 @@ import (
)
// MakeAuthAPI turns a util.JSONRequestHandler function into an http.Handler which checks the access token in the request.
func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
func MakeAuthAPI(tracer opentracing.Tracer, metricsName string, deviceDB auth.DeviceDatabase, f func(*http.Request, *authtypes.Device) util.JSONResponse) http.Handler {
h := func(req *http.Request) util.JSONResponse {
device, resErr := auth.VerifyAccessToken(req, deviceDB)
if resErr != nil {
@ -22,15 +22,15 @@ func MakeAuthAPI(metricsName string, deviceDB auth.DeviceDatabase, f func(*http.
}
return f(req, device)
}
return MakeExternalAPI(metricsName, h)
return MakeExternalAPI(tracer, metricsName, h)
}
// MakeExternalAPI turns a util.JSONRequestHandler function into an http.Handler.
// This is used for APIs that are called from the internet.
func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
func MakeExternalAPI(tracer opentracing.Tracer, metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
h := util.MakeJSONAPI(util.NewJSONRequestHandler(f))
withSpan := func(w http.ResponseWriter, req *http.Request) {
span := opentracing.StartSpan(metricsName)
span := tracer.StartSpan(metricsName)
defer span.Finish()
req = req.WithContext(opentracing.ContextWithSpan(req.Context(), span))
h.ServeHTTP(w, req)
@ -43,11 +43,10 @@ func MakeExternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// This is used for APIs that are internal to dendrite.
// If we are passed a tracing context in the request headers then we use that
// as the parent of any tracing spans we create.
func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
func MakeInternalAPI(tracer opentracing.Tracer, metricsName string, f func(*http.Request) util.JSONResponse) http.Handler {
h := util.MakeJSONAPI(util.NewJSONRequestHandler(f))
withSpan := func(w http.ResponseWriter, req *http.Request) {
carrier := opentracing.HTTPHeadersCarrier(req.Header)
tracer := opentracing.GlobalTracer()
clientContext, err := tracer.Extract(opentracing.HTTPHeaders, carrier)
var span opentracing.Span
if err == nil {
@ -67,6 +66,7 @@ func MakeInternalAPI(metricsName string, f func(*http.Request) util.JSONResponse
// MakeFedAPI makes an http.Handler that checks matrix federation authentication.
func MakeFedAPI(
tracer opentracing.Tracer,
metricsName string,
serverName gomatrixserverlib.ServerName,
keyRing gomatrixserverlib.KeyRing,
@ -81,7 +81,7 @@ func MakeFedAPI(
}
return f(req, fedReq)
}
return MakeExternalAPI(metricsName, h)
return MakeExternalAPI(tracer, metricsName, h)
}
// SetupHTTPAPI registers an HTTP API mux under /api and sets up a metrics

View file

@ -16,8 +16,8 @@ package keydb
import (
"context"
"database/sql"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,8 +31,8 @@ type Database struct {
// It creates the necessary tables if they don't already exist.
// It prepares all the SQL statements that it will use.
// Returns an error if there was a problem talking to the database.
func NewDatabase(dataSourceName string) (*Database, error) {
db, err := sql.Open("postgres", dataSourceName)
func NewDatabase(tracers *common.Tracers, dataSourceName string) (*Database, error) {
db, err := common.OpenPostgresWithTracing(tracers, "keys", dataSourceName)
if err != nil {
return nil, err
}

View file

@ -16,6 +16,13 @@ package common
import (
"database/sql"
"fmt"
"github.com/matrix-org/util"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/lib/pq"
)
// A Transaction is something that can be committed or rolledback.
@ -66,3 +73,18 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
}
return statement
}
// OpenPostgresWithTracing creates a new DB instance where calls will be
// traced with the given tracer
func OpenPostgresWithTracing(tracers *Tracers, databaseName, connstr string) (*sql.DB, error) {
tracer := tracers.SetupNewTracer("sql: " + databaseName)
hooks := othooks.New(tracer)
// This is a hack to get around the fact that you can't directly open
// a sql.DB with a given driver, you *have* to register it.
registrationName := fmt.Sprintf("postgres-ot-%s", util.RandomString(5))
sql.Register(registrationName, sqlhooks.Wrap(&pq.Driver{}, hooks))
return sql.Open(registrationName, connstr)
}

View file

@ -0,0 +1,103 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package common
import (
"io"
"github.com/matrix-org/dendrite/common/config"
"github.com/opentracing/opentracing-go"
"github.com/sirupsen/logrus"
jaegerconfig "github.com/uber/jaeger-client-go/config"
jaegermetrics "github.com/uber/jaeger-lib/metrics"
)
type Tracers struct {
cfg *config.Dendrite
closers []io.Closer
}
func NoopTracers() *Tracers {
return &Tracers{}
}
func NewTracers(cfg *config.Dendrite) *Tracers {
return &Tracers{
cfg: cfg,
}
}
func (t *Tracers) InitGlobalTracer(serviceName string) error {
if t.cfg == nil {
return nil
}
// Set up GlobalTracer
closer, err := t.cfg.Tracing.Jaeger.InitGlobalTracer(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
if err != nil {
return err
}
t.closers = append(t.closers, closer)
return nil
}
// SetupTracing configures the opentracing using the supplied configuration.
func (t *Tracers) SetupNewTracer(serviceName string) opentracing.Tracer {
if t.cfg == nil {
return opentracing.NoopTracer{}
}
tracer, closer, err := t.cfg.Tracing.Jaeger.New(
serviceName,
jaegerconfig.Logger(logrusLogger{logrus.StandardLogger()}),
jaegerconfig.Metrics(jaegermetrics.NullFactory),
)
if err != nil {
logrus.Panicf("Failed to create new tracer %s: %s", serviceName, err)
}
t.closers = append(t.closers, closer)
return tracer
}
func (t *Tracers) Close() error {
for _, c := range t.closers {
c.Close() // nolint: errcheck
}
return nil
}
// logrusLogger is a small wrapper that implements jaeger.Logger using logrus.
type logrusLogger struct {
l *logrus.Logger
}
func (l logrusLogger) Error(msg string) {
l.l.Error(msg)
}
func (l logrusLogger) Infof(msg string, args ...interface{}) {
l.l.Infof(msg, args...)
}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const (
@ -43,11 +44,12 @@ func Setup(
keys gomatrixserverlib.KeyRing,
federation *gomatrixserverlib.FederationClient,
accountDB *accounts.Database,
tracer opentracing.Tracer,
) {
v2keysmux := apiMux.PathPrefix(pathPrefixV2Keys).Subrouter()
v1fedmux := apiMux.PathPrefix(pathPrefixV1Federation).Subrouter()
localKeys := common.MakeExternalAPI("localkeys", func(req *http.Request) util.JSONResponse {
localKeys := common.MakeExternalAPI(tracer, "localkeys", func(req *http.Request) util.JSONResponse {
return LocalKeys(cfg)
})
@ -59,7 +61,7 @@ func Setup(
v2keysmux.Handle("/server/", localKeys).Methods("GET")
v1fedmux.Handle("/send/{txnID}/", common.MakeFedAPI(
"federation_send", cfg.Matrix.ServerName, keys,
tracer, "federation_send", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return Send(
@ -70,7 +72,7 @@ func Setup(
)).Methods("PUT", "OPTIONS")
v1fedmux.Handle("/invite/{roomID}/{eventID}", common.MakeFedAPI(
"federation_invite", cfg.Matrix.ServerName, keys,
tracer, "federation_invite", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return Invite(
@ -80,14 +82,15 @@ func Setup(
},
)).Methods("PUT", "OPTIONS")
v1fedmux.Handle("/3pid/onbind", common.MakeExternalAPI("3pid_onbind",
v1fedmux.Handle("/3pid/onbind", common.MakeExternalAPI(
tracer, "3pid_onbind",
func(req *http.Request) util.JSONResponse {
return CreateInvitesFrom3PIDInvites(req, query, cfg, producer, federation, accountDB)
},
)).Methods("POST", "OPTIONS")
v1fedmux.Handle("/exchange_third_party_invite/{roomID}", common.MakeFedAPI(
"exchange_third_party_invite", cfg.Matrix.ServerName, keys,
tracer, "exchange_third_party_invite", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return ExchangeThirdPartyInvite(
@ -97,7 +100,7 @@ func Setup(
)).Methods("PUT", "OPTIONS")
v1fedmux.Handle("/event/{eventID}", common.MakeFedAPI(
"federation_get_event", cfg.Matrix.ServerName, keys,
tracer, "federation_get_event", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
return GetEvent(
@ -107,7 +110,7 @@ func Setup(
)).Methods("GET")
v1fedmux.Handle("/query/directory/", common.MakeFedAPI(
"federation_query_room_alias", cfg.Matrix.ServerName, keys,
tracer, "federation_query_room_alias", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
return RoomAliasToID(
httpReq, federation, cfg, aliasAPI,
@ -116,7 +119,7 @@ func Setup(
)).Methods("GET")
v1fedmux.Handle("/query/profile", common.MakeFedAPI(
"federation_query_profile", cfg.Matrix.ServerName, keys,
tracer, "federation_query_profile", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
return GetProfile(
httpReq, accountDB, cfg,
@ -125,7 +128,7 @@ func Setup(
)).Methods("GET")
v1fedmux.Handle("/make_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_make_join", cfg.Matrix.ServerName, keys,
tracer, "federation_make_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -137,7 +140,7 @@ func Setup(
)).Methods("GET")
v1fedmux.Handle("/send_join/{roomID}/{userID}", common.MakeFedAPI(
"federation_send_join", cfg.Matrix.ServerName, keys,
tracer, "federation_send_join", cfg.Matrix.ServerName, keys,
func(httpReq *http.Request, request *gomatrixserverlib.FederationRequest) util.JSONResponse {
vars := mux.Vars(httpReq)
roomID := vars["roomID"]
@ -149,7 +152,7 @@ func Setup(
)).Methods("PUT")
v1fedmux.Handle("/version", common.MakeExternalAPI(
"federation_version",
tracer, "federation_version",
func(httpReq *http.Request) util.JSONResponse {
return Version()
},

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/federationsender/types"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -36,6 +37,7 @@ type OutputRoomEventConsumer struct {
db *storage.Database
queues *queue.OutgoingQueues
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -45,6 +47,7 @@ func NewOutputRoomEventConsumer(
queues *queue.OutgoingQueues,
store *storage.Database,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
@ -56,6 +59,7 @@ func NewOutputRoomEventConsumer(
db: store,
queues: queues,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -86,7 +90,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background())
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
ev := &output.NewRoomEvent.Event

View file

@ -31,10 +31,10 @@ type Database struct {
}
// NewDatabase opens a new database
func NewDatabase(dataSourceName string) (*Database, error) {
func NewDatabase(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var result Database
var err error
if result.db, err = sql.Open("postgres", dataSourceName); err != nil {
if result.db, err = common.OpenPostgresWithTracing(tracers, "federationsender", dataSourceName); err != nil {
return nil, err
}
if err = result.prepare(); err != nil {

View file

@ -18,6 +18,7 @@ import (
"net/http"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
opentracing "github.com/opentracing/opentracing-go"
"github.com/gorilla/mux"
"github.com/matrix-org/dendrite/clientapi/auth/storage/devices"
@ -39,6 +40,7 @@ func Setup(
db *storage.Database,
deviceDB *devices.Database,
client *gomatrixserverlib.Client,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
@ -47,6 +49,7 @@ func Setup(
}
r0mux.Handle("/upload", common.MakeAuthAPI(
tracer,
"upload",
deviceDB,
func(req *http.Request, _ *authtypes.Device) util.JSONResponse {
@ -73,6 +76,8 @@ func makeDownloadAPI(
activeRemoteRequests *types.ActiveRemoteRequests,
activeThumbnailGeneration *types.ActiveThumbnailGeneration,
) http.HandlerFunc {
// TODO: Add opentracing.
return prometheus.InstrumentHandler(name, http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
req = util.RequestWithLogging(req)

View file

@ -18,8 +18,7 @@ import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/mediaapi/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,10 +30,10 @@ type Database struct {
}
// Open opens a postgres database.
func Open(dataSourceName string) (*Database, error) {
func Open(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var d Database
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "media", dataSourceName); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/dendrite/roomserver/api"
opentracing "github.com/opentracing/opentracing-go"
log "github.com/sirupsen/logrus"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -31,6 +32,7 @@ type OutputRoomEventConsumer struct {
roomServerConsumer *common.ContinualConsumer
db *storage.PublicRoomsServerDatabase
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -39,6 +41,7 @@ func NewOutputRoomEventConsumer(
kafkaConsumer sarama.Consumer,
store *storage.PublicRoomsServerDatabase,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
Topic: string(cfg.Kafka.Topics.OutputRoomEvent),
@ -49,6 +52,7 @@ func NewOutputRoomEventConsumer(
roomServerConsumer: &consumer,
db: store,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -77,7 +81,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background())
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
ev := output.NewRoomEvent.Event

View file

@ -24,27 +24,33 @@ import (
"github.com/matrix-org/dendrite/publicroomsapi/directory"
"github.com/matrix-org/dendrite/publicroomsapi/storage"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixR0 = "/_matrix/client/r0"
// Setup configures the given mux with publicroomsapi server listeners
func Setup(apiMux *mux.Router, deviceDB *devices.Database, publicRoomsDB *storage.PublicRoomsServerDatabase) {
func Setup(
apiMux *mux.Router,
deviceDB *devices.Database,
publicRoomsDB *storage.PublicRoomsServerDatabase,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
r0mux.Handle("/directory/list/room/{roomID}",
common.MakeExternalAPI("directory_list", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "directory_list", func(req *http.Request) util.JSONResponse {
vars := mux.Vars(req)
return directory.GetVisibility(req, publicRoomsDB, vars["roomID"])
}),
).Methods("GET", "OPTIONS")
r0mux.Handle("/directory/list/room/{roomID}",
common.MakeAuthAPI("directory_list", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
common.MakeAuthAPI(tracer, "directory_list", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return directory.SetVisibility(req, publicRoomsDB, vars["roomID"])
}),
).Methods("PUT", "OPTIONS")
r0mux.Handle("/publicRooms",
common.MakeExternalAPI("public_rooms", func(req *http.Request) util.JSONResponse {
common.MakeExternalAPI(tracer, "public_rooms", func(req *http.Request) util.JSONResponse {
return directory.GetPublicRooms(req, publicRoomsDB)
}),
).Methods("GET", "POST", "OPTIONS")

View file

@ -35,10 +35,10 @@ type PublicRoomsServerDatabase struct {
type attributeValue interface{}
// NewPublicRoomsServerDatabase creates a new public rooms server database.
func NewPublicRoomsServerDatabase(dataSourceName string) (*PublicRoomsServerDatabase, error) {
func NewPublicRoomsServerDatabase(tracers *common.Tracers, dataSourceName string) (*PublicRoomsServerDatabase, error) {
var db *sql.DB
var err error
if db, err = sql.Open("postgres", dataSourceName); err != nil {
if db, err = common.OpenPostgresWithTracing(tracers, "publicrooms", dataSourceName); err != nil {
return nil, err
}
partitions := common.PartitionOffsetStatements{}

View file

@ -26,6 +26,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
// RoomserverAliasAPIDatabase has the storage APIs needed to implement the alias API.
@ -213,10 +214,10 @@ func (r *RoomserverAliasAPI) sendUpdatedAliasesEvent(
}
// SetupHTTP adds the RoomserverAliasAPI handlers to the http.ServeMux.
func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(
api.RoomserverSetRoomAliasPath,
common.MakeInternalAPI("setRoomAlias", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "setRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.SetRoomAliasRequest
var response api.SetRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -230,7 +231,7 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverGetAliasRoomIDPath,
common.MakeInternalAPI("getAliasRoomID", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "getAliasRoomID", func(req *http.Request) util.JSONResponse {
var request api.GetAliasRoomIDRequest
var response api.GetAliasRoomIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -244,7 +245,7 @@ func (r *RoomserverAliasAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverRemoveRoomAliasPath,
common.MakeInternalAPI("removeRoomAlias", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "removeRoomAlias", func(req *http.Request) util.JSONResponse {
var request api.RemoveRoomAliasRequest
var response api.RemoveRoomAliasResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {

View file

@ -69,9 +69,8 @@ func (o *OutputEvent) AddSpanFromContext(ctx context.Context) error {
// StartSpanAndReplaceContext produces a context and opentracing span from the
// info embedded in OutputEvent
func (o *OutputEvent) StartSpanAndReplaceContext(
ctx context.Context,
ctx context.Context, tracer opentracing.Tracer,
) (context.Context, opentracing.Span) {
tracer := opentracing.GlobalTracer()
producerContext, err := tracer.Extract(opentracing.TextMap, o.OpentracingCarrier)
var span opentracing.Span

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/api"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
sarama "gopkg.in/Shopify/sarama.v1"
)
@ -72,9 +73,9 @@ func (r *RoomserverInputAPI) InputRoomEvents(
}
// SetupHTTP adds the RoomserverInputAPI handlers to the http.ServeMux.
func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(api.RoomserverInputRoomEventsPath,
common.MakeInternalAPI("inputRoomEvents", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "inputRoomEvents", func(req *http.Request) util.JSONResponse {
var request api.InputRoomEventsRequest
var response api.InputRoomEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -87,3 +88,29 @@ func (r *RoomserverInputAPI) SetupHTTP(servMux *http.ServeMux) {
}),
)
}
type InProcessRoomServerInput struct {
db RoomserverInputAPI
tracer opentracing.Tracer
}
func NewInProcessRoomServerInput(db RoomserverInputAPI, tracer opentracing.Tracer) *InProcessRoomServerInput {
return &InProcessRoomServerInput{
db: db, tracer: tracer,
}
}
func (r *InProcessRoomServerInput) InputRoomEvents(
ctx context.Context,
request *api.InputRoomEventsRequest,
response *api.InputRoomEventsResponse,
) error {
span := r.tracer.StartSpan(
"InputRoomEvents",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return r.db.InputRoomEvents(ctx, request, response)
}

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
// RoomserverQueryAPIEventDB has a convenience API to fetch events directly by
@ -522,10 +523,10 @@ func getAuthChain(
// SetupHTTP adds the RoomserverQueryAPI handlers to the http.ServeMux.
// nolint: gocyclo
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux, tracer opentracing.Tracer) {
servMux.Handle(
api.RoomserverQueryLatestEventsAndStatePath,
common.MakeInternalAPI("queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryLatestEventsAndState", func(req *http.Request) util.JSONResponse {
var request api.QueryLatestEventsAndStateRequest
var response api.QueryLatestEventsAndStateResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -539,7 +540,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryStateAfterEventsPath,
common.MakeInternalAPI("queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryStateAfterEvents", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAfterEventsRequest
var response api.QueryStateAfterEventsResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -553,7 +554,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryEventsByIDPath,
common.MakeInternalAPI("queryEventsByID", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryEventsByID", func(req *http.Request) util.JSONResponse {
var request api.QueryEventsByIDRequest
var response api.QueryEventsByIDResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -567,7 +568,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryMembershipsForRoomPath,
common.MakeInternalAPI("queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryMembershipsForRoom", func(req *http.Request) util.JSONResponse {
var request api.QueryMembershipsForRoomRequest
var response api.QueryMembershipsForRoomResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -581,7 +582,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryInvitesForUserPath,
common.MakeInternalAPI("queryInvitesForUser", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryInvitesForUser", func(req *http.Request) util.JSONResponse {
var request api.QueryInvitesForUserRequest
var response api.QueryInvitesForUserResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -595,7 +596,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryServerAllowedToSeeEventPath,
common.MakeInternalAPI("queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryServerAllowedToSeeEvent", func(req *http.Request) util.JSONResponse {
var request api.QueryServerAllowedToSeeEventRequest
var response api.QueryServerAllowedToSeeEventResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -609,7 +610,7 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
)
servMux.Handle(
api.RoomserverQueryStateAndAuthChainPath,
common.MakeInternalAPI("queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
common.MakeInternalAPI(tracer, "queryStateAndAuthChain", func(req *http.Request) util.JSONResponse {
var request api.QueryStateAndAuthChainRequest
var response api.QueryStateAndAuthChainResponse
if err := json.NewDecoder(req.Body).Decode(&request); err != nil {
@ -622,3 +623,127 @@ func (r *RoomserverQueryAPI) SetupHTTP(servMux *http.ServeMux) {
}),
)
}
type InProcessRoomServerQueryAPI struct {
db RoomserverQueryAPI
tracer opentracing.Tracer
}
func NewInProcessRoomServerQueryAPI(db RoomserverQueryAPI, tracer opentracing.Tracer) InProcessRoomServerQueryAPI {
return InProcessRoomServerQueryAPI{
db: db,
tracer: tracer,
}
}
// QueryLatestEventsAndState implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryLatestEventsAndState(
ctx context.Context,
request *api.QueryLatestEventsAndStateRequest,
response *api.QueryLatestEventsAndStateResponse,
) error {
span := h.tracer.StartSpan(
"QueryLatestEventsAndState",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryLatestEventsAndState(ctx, request, response)
}
// QueryStateAfterEvents implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryStateAfterEvents(
ctx context.Context,
request *api.QueryStateAfterEventsRequest,
response *api.QueryStateAfterEventsResponse,
) error {
span := h.tracer.StartSpan(
"QueryStateAfterEvents",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryStateAfterEvents(ctx, request, response)
}
// QueryEventsByID implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryEventsByID(
ctx context.Context,
request *api.QueryEventsByIDRequest,
response *api.QueryEventsByIDResponse,
) error {
span := h.tracer.StartSpan(
"QueryEventsByID",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryEventsByID(ctx, request, response)
}
// QueryMembershipsForRoom implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryMembershipsForRoom(
ctx context.Context,
request *api.QueryMembershipsForRoomRequest,
response *api.QueryMembershipsForRoomResponse,
) error {
span := h.tracer.StartSpan(
"QueryMembershipsForRoom",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryMembershipsForRoom(ctx, request, response)
}
// QueryInvitesForUser implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryInvitesForUser(
ctx context.Context,
request *api.QueryInvitesForUserRequest,
response *api.QueryInvitesForUserResponse,
) error {
span := h.tracer.StartSpan(
"QueryInvitesForUser",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryInvitesForUser(ctx, request, response)
}
// QueryServerAllowedToSeeEvent implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryServerAllowedToSeeEvent(
ctx context.Context,
request *api.QueryServerAllowedToSeeEventRequest,
response *api.QueryServerAllowedToSeeEventResponse,
) (err error) {
span := h.tracer.StartSpan(
"QueryServerAllowedToSeeEvent",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryServerAllowedToSeeEvent(ctx, request, response)
}
// QueryStateAndAuthChain implements RoomserverQueryAPI
func (h *InProcessRoomServerQueryAPI) QueryStateAndAuthChain(
ctx context.Context,
request *api.QueryStateAndAuthChainRequest,
response *api.QueryStateAndAuthChainResponse,
) error {
span := h.tracer.StartSpan(
"QueryStateAndAuthChain",
opentracing.ChildOf(opentracing.SpanFromContext(ctx).Context()),
)
defer span.Finish()
ctx = opentracing.ContextWithSpan(ctx, span)
return h.db.QueryStateAndAuthChain(ctx, request, response)
}

View file

@ -18,8 +18,7 @@ import (
"context"
"database/sql"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib"
)
@ -31,10 +30,10 @@ type Database struct {
}
// Open a postgres database.
func Open(dataSourceName string) (*Database, error) {
func Open(tracers *common.Tracers, dataSourceName string) (*Database, error) {
var d Database
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "roomserver", dataSourceName); err != nil {
return nil, err
}
if err = d.statements.prepare(d.db); err != nil {

View file

@ -19,6 +19,8 @@ import (
"encoding/json"
"fmt"
"github.com/opentracing/opentracing-go"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/common/config"
"github.com/matrix-org/dendrite/roomserver/api"
@ -36,6 +38,7 @@ type OutputRoomEventConsumer struct {
db *storage.SyncServerDatabase
notifier *sync.Notifier
query api.RoomserverQueryAPI
tracer opentracing.Tracer
}
// NewOutputRoomEventConsumer creates a new OutputRoomEventConsumer. Call Start() to begin consuming from room servers.
@ -45,6 +48,7 @@ func NewOutputRoomEventConsumer(
n *sync.Notifier,
store *storage.SyncServerDatabase,
queryAPI api.RoomserverQueryAPI,
tracer opentracing.Tracer,
) *OutputRoomEventConsumer {
consumer := common.ContinualConsumer{
@ -57,6 +61,7 @@ func NewOutputRoomEventConsumer(
db: store,
notifier: n,
query: queryAPI,
tracer: tracer,
}
consumer.ProcessMessage = s.onMessage
@ -80,7 +85,7 @@ func (s *OutputRoomEventConsumer) onMessage(msg *sarama.ConsumerMessage) error {
return nil
}
ctx, span := output.StartSpanAndReplaceContext(context.Background())
ctx, span := output.StartSpanAndReplaceContext(context.Background(), s.tracer)
defer span.Finish()
switch output.Type {

View file

@ -24,29 +24,33 @@ import (
"github.com/matrix-org/dendrite/syncapi/storage"
"github.com/matrix-org/dendrite/syncapi/sync"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
const pathPrefixR0 = "/_matrix/client/r0"
// Setup configures the given mux with sync-server listeners
func Setup(apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database) {
func Setup(
apiMux *mux.Router, srp *sync.RequestPool, syncDB *storage.SyncServerDatabase, deviceDB *devices.Database,
tracer opentracing.Tracer,
) {
r0mux := apiMux.PathPrefix(pathPrefixR0).Subrouter()
r0mux.Handle("/sync", common.MakeAuthAPI("sync", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/sync", common.MakeAuthAPI(tracer, "sync", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
return srp.OnIncomingSyncRequest(req, device)
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state", common.MakeAuthAPI(tracer, "room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateRequest(req, syncDB, vars["roomID"])
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state/{type}", common.MakeAuthAPI(tracer, "room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], "")
})).Methods("GET", "OPTIONS")
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI("room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
r0mux.Handle("/rooms/{roomID}/state/{type}/{stateKey}", common.MakeAuthAPI(tracer, "room_state", deviceDB, func(req *http.Request, device *authtypes.Device) util.JSONResponse {
vars := mux.Vars(req)
return OnIncomingStateTypeRequest(req, syncDB, vars["roomID"], vars["type"], vars["stateKey"])
})).Methods("GET", "OPTIONS")

View file

@ -19,7 +19,7 @@ import (
"database/sql"
"fmt"
// Import the postgres database driver.
_ "github.com/lib/pq"
"github.com/matrix-org/dendrite/common"
"github.com/matrix-org/dendrite/syncapi/types"
"github.com/matrix-org/gomatrixserverlib"
@ -51,10 +51,10 @@ type SyncServerDatabase struct {
}
// NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
func NewSyncServerDatabase(tracers *common.Tracers, dataSourceName string) (*SyncServerDatabase, error) {
var d SyncServerDatabase
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracers, "sync", dataSourceName); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {

6
vendor/manifest vendored
View file

@ -77,6 +77,12 @@
"revision": "44cc805cf13205b55f69e14bcb69867d1ae92f98",
"branch": "master"
},
{
"importpath": "github.com/gchaincl/sqlhooks",
"repository": "https://github.com/gchaincl/sqlhooks",
"revision": "b4a12bad76664eae8012d196ed901f8fa8f87909",
"branch": "master"
},
{
"importpath": "github.com/golang/protobuf/proto",
"repository": "https://github.com/golang/protobuf",

View file

@ -0,0 +1,41 @@
# Change Log
## [Unreleased](https://github.com/gchaincl/sqlhooks/tree/HEAD)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v1.0.0...HEAD)
**Closed issues:**
- Add Benchmarks [\#9](https://github.com/gchaincl/sqlhooks/issues/9)
## [v1.0.0](https://github.com/gchaincl/sqlhooks/tree/v1.0.0) (2017-05-08)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.4...v1.0.0)
**Merged pull requests:**
- Godoc [\#7](https://github.com/gchaincl/sqlhooks/pull/7) ([gchaincl](https://github.com/gchaincl))
- Make covermode=count [\#6](https://github.com/gchaincl/sqlhooks/pull/6) ([gchaincl](https://github.com/gchaincl))
- V1 [\#5](https://github.com/gchaincl/sqlhooks/pull/5) ([gchaincl](https://github.com/gchaincl))
- Expose a WrapDriver function [\#4](https://github.com/gchaincl/sqlhooks/issues/4)
- Implement new 1.8 interfaces [\#3](https://github.com/gchaincl/sqlhooks/issues/3)
## [v0.4](https://github.com/gchaincl/sqlhooks/tree/v0.4) (2017-03-23)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.3...v0.4)
## [v0.3](https://github.com/gchaincl/sqlhooks/tree/v0.3) (2016-06-02)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.2...v0.3)
**Closed issues:**
- Change Notifications [\#2](https://github.com/gchaincl/sqlhooks/issues/2)
## [v0.2](https://github.com/gchaincl/sqlhooks/tree/v0.2) (2016-05-01)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.1...v0.2)
## [v0.1](https://github.com/gchaincl/sqlhooks/tree/v0.1) (2016-04-25)
**Merged pull requests:**
- Sqlite3 [\#1](https://github.com/gchaincl/sqlhooks/pull/1) ([gchaincl](https://github.com/gchaincl))
\* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)*

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Gustavo Chaín
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,82 @@
# sqlhooks [![Build Status](https://travis-ci.org/gchaincl/sqlhooks.svg)](https://travis-ci.org/gchaincl/sqlhooks) [![Coverage Status](https://coveralls.io/repos/github/gchaincl/sqlhooks/badge.svg?branch=master)](https://coveralls.io/github/gchaincl/sqlhooks?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/gchaincl/sqlhooks)](https://goreportcard.com/report/github.com/gchaincl/sqlhooks)
Attach hooks to any database/sql driver.
The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
# Install
```bash
go get github.com/gchaincl/sqlhooks
```
## Breaking changes
`V1` isn't backward compatible with previous versions, if you want to fetch old versions, you can get them from [gopkg.in](http://gopkg.in/)
```bash
go get gopkg.in/gchaincl/sqlhooks.v0
```
# Usage [![GoDoc](https://godoc.org/github.com/gchaincl/dotsql?status.svg)](https://godoc.org/github.com/gchaincl/sqlhooks)
```go
// This example shows how to instrument sql queries in order to display the time that they consume
package main
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/gchaincl/sqlhooks"
"github.com/mattn/go-sqlite3"
)
// Hooks satisfies the sqlhook.Hooks interface
type Hooks struct {}
// Before hook will print the query with it's args and return the context with the timestamp
func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
fmt.Printf("> %s %q", query, args)
return context.WithValue(ctx, "begin", time.Now()), nil
}
// After hook will get the timestamp registered on the Before hook and print the elapsed time
func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
begin := ctx.Value("begin").(time.Time)
fmt.Printf(". took: %s\n", time.Since(begin))
return ctx, nil
}
func main() {
// First, register the wrapper
sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
// Connect to the registered wrapped driver
db, _ := sql.Open("sqlite3WithHooks", ":memory:")
// Do you're stuff
db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
db.Query("SELECT id, text FROM t")
}
/*
Output should look like:
> CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
> INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
> SELECT id, text FROM t []. took: 4.653µs
*/
```
# Benchmarks
```
go test -bench=. -benchmem
BenchmarkSQLite3/Without_Hooks-4 200000 8572 ns/op 627 B/op 16 allocs/op
BenchmarkSQLite3/With_Hooks-4 200000 10231 ns/op 738 B/op 18 allocs/op
BenchmarkMySQL/Without_Hooks-4 10000 108421 ns/op 437 B/op 10 allocs/op
BenchmarkMySQL/With_Hooks-4 10000 226085 ns/op 597 B/op 13 allocs/op
BenchmarkPostgres/Without_Hooks-4 10000 125718 ns/op 649 B/op 17 allocs/op
BenchmarkPostgres/With_Hooks-4 5000 354831 ns/op 1122 B/op 27 allocs/op
PASS
ok github.com/gchaincl/sqlhooks 11.713s
```

View file

@ -0,0 +1,76 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func init() {
hooks := &testHooks{}
hooks.noop()
sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks))
sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks))
sql.Register("postgres-benchmark", Wrap(&pq.Driver{}, hooks))
}
func benchmark(b *testing.B, driver, dsn string) {
db, err := sql.Open(driver, dsn)
require.NoError(b, err)
defer db.Close()
var query = "SELECT 'hello'"
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query(query)
require.NoError(b, err)
require.NoError(b, rows.Close())
}
}
func BenchmarkSQLite3(b *testing.B) {
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "sqlite3", ":memory:")
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "sqlite3-benchmark", ":memory:")
})
}
func BenchmarkMySQL(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "mysql", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "mysql-benchmark", dsn)
})
}
func BenchmarkPostgres(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "postgres", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "postgres-benchmark", dsn)
})
}

View file

@ -0,0 +1,52 @@
// package sqlhooks allows you to attach hooks to any database/sql driver.
// The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
// This example shows how to instrument sql queries in order to display the time that they consume
// package main
//
// import (
// "context"
// "database/sql"
// "fmt"
// "time"
//
// "github.com/gchaincl/sqlhooks"
// "github.com/mattn/go-sqlite3"
// )
//
// // Hooks satisfies the sqlhook.Hooks interface
// type Hooks struct {}
//
// // Before hook will print the query with it's args and return the context with the timestamp
// func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// fmt.Printf("> %s %q", query, args)
// return context.WithValue(ctx, "begin", time.Now()), nil
// }
//
// // After hook will get the timestamp registered on the Before hook and print the elapsed time
// func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// begin := ctx.Value("begin").(time.Time)
// fmt.Printf(". took: %s\n", time.Since(begin))
// return ctx, nil
// }
//
// func main() {
// // First, register the wrapper
// sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
//
// // Connect to the registered wrapped driver
// db, _ := sql.Open("sqlite3WithHooks", ":memory:")
//
// // Do you're stuff
// db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
// db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
// db.Query("SELECT id, text FROM t")
// }
//
// /*
// Output should look like:
// > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
// > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
// > SELECT id, text FROM t []. took: 4.653µs
// */
package sqlhooks

View file

@ -0,0 +1,17 @@
package loghooks
import (
"database/sql"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
)
func Example() {
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New())
sql.Register("sqlite3-logger", driver)
db, _ := sql.Open("sqlite3-logger", ":memory:")
// This query will output logs
db.Query("SELECT 1+1")
}

View file

@ -0,0 +1,31 @@
package main
import (
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/loghooks"
"github.com/mattn/go-sqlite3"
)
func main() {
sql.Register("sqlite3log", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, loghooks.New()))
db, err := sql.Open("sqlite3log", ":memory:")
if err != nil {
log.Fatal(err)
}
if _, err := db.Exec("CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.Exec(`INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.Query(`SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,30 @@
package loghooks
import (
"context"
"log"
"os"
"time"
)
type logger interface {
Printf(string, ...interface{})
}
type Hook struct {
log logger
}
func New() *Hook {
return &Hook{
log: log.New(os.Stderr, "", log.LstdFlags),
}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return context.WithValue(ctx, "started", time.Now()), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value("started").(time.Time)))
return ctx, nil
}

View file

@ -0,0 +1,39 @@
package main
import (
"context"
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/mattn/go-sqlite3"
"github.com/opentracing/opentracing-go"
)
func main() {
tracer := opentracing.GlobalTracer()
hooks := othooks.New(tracer)
sql.Register("sqlite3ot", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks))
db, err := sql.Open("sqlite3ot", ":memory:")
if err != nil {
log.Fatal(err)
}
span := tracer.StartSpan("sql")
defer span.Finish()
ctx := opentracing.ContextWithSpan(context.Background(), span)
if _, err := db.ExecContext(ctx, "CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.QueryContext(ctx, `SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,35 @@
package othooks
import "context"
import "github.com/opentracing/opentracing-go"
import "github.com/opentracing/opentracing-go/ext"
type Hook struct {
tracer opentracing.Tracer
}
func New(tracer opentracing.Tracer) *Hook {
return &Hook{tracer: tracer}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
parent := opentracing.SpanFromContext(ctx)
if parent == nil {
return ctx, nil
}
span := h.tracer.StartSpan("sql", opentracing.ChildOf(parent.Context()))
ext.DBStatement.Set(span, query)
return opentracing.ContextWithSpan(ctx, span), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
span := opentracing.SpanFromContext(ctx)
if span != nil {
defer span.Finish()
}
return ctx, nil
}

View file

@ -0,0 +1,74 @@
package othooks
import (
"context"
"database/sql"
"testing"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
tracer *mocktracer.MockTracer
)
func init() {
tracer = mocktracer.New()
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New(tracer))
sql.Register("ot", driver)
}
func TestSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
parent := tracer.StartSpan("parent")
ctx := opentracing.ContextWithSpan(context.Background(), parent)
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
parent.Finish()
spans := tracer.FinishedSpans()
require.Len(t, spans, 3)
span := spans[1]
assert.Equal(t, "sql", span.OperationName)
logFields := span.Logs()[0].Fields
assert.Equal(t, "query", logFields[0].Key)
assert.Equal(t, "SELECT 1+?", logFields[0].ValueString)
assert.Equal(t, "args", logFields[1].Key)
assert.Equal(t, "[1]", logFields[1].ValueString)
assert.NotEmpty(t, span.FinishTime)
}
func TesNoSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
rows, err := db.QueryContext(context.Background(), "SELECT 1")
require.NoError(t, err)
rows.Close()
assert.Empty(t, tracer.FinishedSpans())
}

View file

@ -0,0 +1,277 @@
package sqlhooks
import (
"context"
"errors"
"database/sql/driver"
)
// Hook is the hook callback signature
type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error)
// Hooks instances may be passed to Wrap() to define an instrumented driver
type Hooks interface {
Before(ctx context.Context, query string, args ...interface{}) (context.Context, error)
After(ctx context.Context, query string, args ...interface{}) (context.Context, error)
}
// Driver implements a database/sql/driver.Driver
type Driver struct {
driver.Driver
hooks Hooks
}
// Open opens a connection
func (drv *Driver) Open(name string) (driver.Conn, error) {
conn, err := drv.Driver.Open(name)
if err != nil {
return conn, err
}
wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
}
return wrapped, nil
}
// Conn implements a database/sql.driver.Conn
type Conn struct {
Conn driver.Conn
hooks Hooks
}
func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
var (
stmt driver.Stmt
err error
)
if c, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmt, err = c.PrepareContext(ctx, query)
} else {
stmt, err = conn.Prepare(query)
}
if err != nil {
return stmt, err
}
return &Stmt{stmt, conn.hooks, query}, nil
}
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
}
// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
*Conn
}
func isExecer(conn driver.Conn) bool {
switch conn.(type) {
case driver.ExecerContext:
return true
case driver.Execer:
return true
default:
return false
}
}
func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
switch c := conn.Conn.Conn.(type) {
case driver.ExecerContext:
return c.ExecContext(ctx, query, args)
case driver.Execer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Exec(query, dargs)
default:
// This should not happen
return nil, errors.New("ExecerContext created for a non Execer driver.Conn")
}
}
func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}
results, err := conn.execContext(ctx, query, args)
if err != nil {
return results, err
}
if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}
return results, err
}
func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) {
// We have to implement Exec since it is required in the current version of
// Go for it to run ExecContext. From Go 10 it will be optional. However,
// this code should never run since database/sql always prefers to run
// ExecContext.
return nil, errors.New("Exec was called when ExecContext was implemented")
}
// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
hooks Hooks
query string
}
func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s, ok := stmt.Stmt.(driver.StmtExecContext); ok {
return s.ExecContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Exec(values)
}
func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
results, err := stmt.execContext(ctx, args)
if err != nil {
return results, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return results, err
}
func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
return s.QueryContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Query(values)
}
func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
var err error
list := namedToInterface(args)
// Exec Before Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
rows, err := stmt.queryContext(ctx, args)
if err != nil {
return rows, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return rows, err
}
func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }
func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() }
func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) }
func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) }
// Wrap is used to create a new instrumented driver, it takes a vendor specific driver, and a Hooks instance to produce a new driver instance.
// It's usually used inside a sql.Register() statement
func Wrap(driver driver.Driver, hooks Hooks) driver.Driver {
return &Driver{driver, hooks}
}
func namedToInterface(args []driver.NamedValue) []interface{} {
list := make([]interface{}, len(args))
for i, a := range args {
list[i] = a.Value
}
return list
}
// namedValueToValue copied from database/sql
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
/*
type hooks struct {
}
func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error {
log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error {
log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func main() {
sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{}))
db, err := sql.Open("sqlite3-proxy", ":memory:")
if err != nil {
log.Fatalln(err)
}
if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok {
panic("NOPE")
}
if _, err := db.Exec("CREATE table users(id int)"); err != nil {
log.Printf("|err| = %+v\n", err)
}
if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil {
log.Printf("err = %+v\n", err)
}
}
*/

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpMySQL(t *testing.T, dsn string) {
db, err := sql.Open("mysql", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestMySQL(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
setUpMySQL(t, dsn)
s := newSuite(t, &mysql.MySQLDriver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpPostgres(t *testing.T, dsn string) {
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestPostgres(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
setUpPostgres(t, dsn)
s := newSuite(t, &pq.Driver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,54 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUp(t *testing.T) func() {
dbName := "sqlite3test.db"
db, err := sql.Open("sqlite3", dbName)
require.NoError(t, err)
defer db.Close()
_, err = db.Exec("CREATE table users(id int, name text)")
require.NoError(t, err)
return func() { os.Remove(dbName) }
}
func TestSQLite3(t *testing.T) {
defer setUp(t)()
s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db")
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for range [5]struct{}{} {
_, err := stmt.Exec(time.Now().UnixNano(), "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,167 @@
package sqlhooks
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testHooks struct {
before Hook
after Hook
}
func (h *testHooks) noop() {
noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}
h.before, h.after = noop, noop
}
func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.before(ctx, query, args...)
}
func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.after(ctx, query, args...)
}
type suite struct {
db *sql.DB
hooks *testHooks
}
func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
hooks := &testHooks{}
driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String())
sql.Register(driverName, Wrap(driver, hooks))
db, err := sql.Open(driverName, dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
return &suite{db, hooks}
}
func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
var before, after bool
s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
before = true
return ctx, nil
}
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
after = true
return ctx, nil
}
t.Run("Query", func(t *testing.T) {
before, after = false, false
_, err := s.db.Query(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("QueryContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.QueryContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Exec", func(t *testing.T) {
before, after = false, false
_, err := s.db.Exec(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("ExecContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.ExecContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Statements", func(t *testing.T) {
before, after = false, false
stmt, err := s.db.Prepare(query)
require.NoError(t, err)
// Hooks just run when the stmt is executed (Query or Exec)
assert.False(t, before, "Before Hook run before execution: "+query)
assert.False(t, after, "After Hook run before execution: "+query)
stmt.Query(args...)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
}
func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) {
hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
assert.Equal(t, query, q)
assert.Equal(t, args, a)
assert.Equal(t, "val", ctx.Value("key").(string))
return ctx, nil
}
s.hooks.before = hook
s.hooks.after = hook
ctx := context.WithValue(context.Background(), "key", "val")
{
_, err := s.db.QueryContext(ctx, query, args...)
require.NoError(t, err)
}
{
_, err := s.db.ExecContext(ctx, query, args...)
require.NoError(t, err)
}
}
func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) {
t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) })
}
func (s *suite) testHooksErrors(t *testing.T, query string) {
boom := errors.New("boom")
s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, boom
}
s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
assert.False(t, true, "this should not run")
return ctx, nil
}
_, err := s.db.Query(query)
assert.Equal(t, boom, err)
}
func (s *suite) TestHooksErrors(t *testing.T, query string) {
t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) })
}
func TestNamedValueToValue(t *testing.T) {
named := []driver.NamedValue{
{Ordinal: 1, Value: "foo"},
{Ordinal: 2, Value: 42},
}
want := []driver.Value{"foo", 42}
dargs, err := namedValueToValue(named)
require.NoError(t, err)
assert.Equal(t, want, dargs)
}