Start sql tracing

This commit is contained in:
Erik Johnston 2017-12-13 10:37:28 +00:00
parent 7e07f8ae7d
commit 9ed5205b84
22 changed files with 1190 additions and 8 deletions

View file

@ -16,6 +16,13 @@ package common
import (
"database/sql"
"fmt"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/lib/pq"
"github.com/matrix-org/util"
opentracing "github.com/opentracing/opentracing-go"
)
// A Transaction is something that can be committed or rolledback.
@ -66,3 +73,22 @@ func TxStmt(transaction *sql.Tx, statement *sql.Stmt) *sql.Stmt {
}
return statement
}
type NewTracerFactory interface {
CreateNewTracer(name string) opentracing.Tracer
}
// OpenPostgresWithTracing creates a new DB instance where calls will be
// traced with the given tracer
func OpenPostgresWithTracing(tracerFactory NewTracerFactory, databaseName, connstr string) (*sql.DB, error) {
tracer := tracerFactory.CreateNewTracer("sql - " + databaseName)
hooks := othooks.New(tracer)
// This is a hack to get around the fact that you can't directly open
// a sql.DB with a given driver, you *have* to register it.
registrationName := fmt.Sprintf("postgres-ot-%s", util.RandomString(5))
sql.Register(registrationName, sqlhooks.Wrap(&pq.Driver{}, hooks))
return sql.Open(registrationName, connstr)
}

View file

@ -54,10 +54,10 @@ type SyncServerDatabase struct {
}
// NewSyncServerDatabase creates a new sync server database
func NewSyncServerDatabase(dataSourceName string) (*SyncServerDatabase, error) {
func NewSyncServerDatabase(tracerFactory common.NewTracerFactory, dataSourceName string) (*SyncServerDatabase, error) {
var d SyncServerDatabase
var err error
if d.db, err = sql.Open("postgres", dataSourceName); err != nil {
if d.db, err = common.OpenPostgresWithTracing(tracerFactory, "sync", dataSourceName); err != nil {
return nil, err
}
if err = d.PartitionOffsetStatements.Prepare(d.db, "syncapi"); err != nil {

View file

@ -39,7 +39,7 @@ func SetupSyncAPIComponent(
) {
tracer := base.CreateNewTracer("SyncAPI")
syncDB, err := storage.NewSyncServerDatabase(string(base.Cfg.Database.SyncAPI))
syncDB, err := storage.NewSyncServerDatabase(base, string(base.Cfg.Database.SyncAPI))
if err != nil {
logrus.WithError(err).Panicf("failed to connect to sync db")
}

6
vendor/manifest vendored
View file

@ -77,6 +77,12 @@
"revision": "44cc805cf13205b55f69e14bcb69867d1ae92f98",
"branch": "master"
},
{
"importpath": "github.com/gchaincl/sqlhooks",
"repository": "https://github.com/gchaincl/sqlhooks",
"revision": "b4a12bad76664eae8012d196ed901f8fa8f87909",
"branch": "master"
},
{
"importpath": "github.com/golang/protobuf/proto",
"repository": "https://github.com/golang/protobuf",

View file

@ -0,0 +1,41 @@
# Change Log
## [Unreleased](https://github.com/gchaincl/sqlhooks/tree/HEAD)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v1.0.0...HEAD)
**Closed issues:**
- Add Benchmarks [\#9](https://github.com/gchaincl/sqlhooks/issues/9)
## [v1.0.0](https://github.com/gchaincl/sqlhooks/tree/v1.0.0) (2017-05-08)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.4...v1.0.0)
**Merged pull requests:**
- Godoc [\#7](https://github.com/gchaincl/sqlhooks/pull/7) ([gchaincl](https://github.com/gchaincl))
- Make covermode=count [\#6](https://github.com/gchaincl/sqlhooks/pull/6) ([gchaincl](https://github.com/gchaincl))
- V1 [\#5](https://github.com/gchaincl/sqlhooks/pull/5) ([gchaincl](https://github.com/gchaincl))
- Expose a WrapDriver function [\#4](https://github.com/gchaincl/sqlhooks/issues/4)
- Implement new 1.8 interfaces [\#3](https://github.com/gchaincl/sqlhooks/issues/3)
## [v0.4](https://github.com/gchaincl/sqlhooks/tree/v0.4) (2017-03-23)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.3...v0.4)
## [v0.3](https://github.com/gchaincl/sqlhooks/tree/v0.3) (2016-06-02)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.2...v0.3)
**Closed issues:**
- Change Notifications [\#2](https://github.com/gchaincl/sqlhooks/issues/2)
## [v0.2](https://github.com/gchaincl/sqlhooks/tree/v0.2) (2016-05-01)
[Full Changelog](https://github.com/gchaincl/sqlhooks/compare/v0.1...v0.2)
## [v0.1](https://github.com/gchaincl/sqlhooks/tree/v0.1) (2016-04-25)
**Merged pull requests:**
- Sqlite3 [\#1](https://github.com/gchaincl/sqlhooks/pull/1) ([gchaincl](https://github.com/gchaincl))
\* *This Change Log was automatically generated by [github_changelog_generator](https://github.com/skywinder/Github-Changelog-Generator)*

View file

@ -0,0 +1,21 @@
The MIT License (MIT)
Copyright (c) 2016 Gustavo Chaín
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -0,0 +1,82 @@
# sqlhooks [![Build Status](https://travis-ci.org/gchaincl/sqlhooks.svg)](https://travis-ci.org/gchaincl/sqlhooks) [![Coverage Status](https://coveralls.io/repos/github/gchaincl/sqlhooks/badge.svg?branch=master)](https://coveralls.io/github/gchaincl/sqlhooks?branch=master) [![Go Report Card](https://goreportcard.com/badge/github.com/gchaincl/sqlhooks)](https://goreportcard.com/report/github.com/gchaincl/sqlhooks)
Attach hooks to any database/sql driver.
The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
# Install
```bash
go get github.com/gchaincl/sqlhooks
```
## Breaking changes
`V1` isn't backward compatible with previous versions, if you want to fetch old versions, you can get them from [gopkg.in](http://gopkg.in/)
```bash
go get gopkg.in/gchaincl/sqlhooks.v0
```
# Usage [![GoDoc](https://godoc.org/github.com/gchaincl/dotsql?status.svg)](https://godoc.org/github.com/gchaincl/sqlhooks)
```go
// This example shows how to instrument sql queries in order to display the time that they consume
package main
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/gchaincl/sqlhooks"
"github.com/mattn/go-sqlite3"
)
// Hooks satisfies the sqlhook.Hooks interface
type Hooks struct {}
// Before hook will print the query with it's args and return the context with the timestamp
func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
fmt.Printf("> %s %q", query, args)
return context.WithValue(ctx, "begin", time.Now()), nil
}
// After hook will get the timestamp registered on the Before hook and print the elapsed time
func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
begin := ctx.Value("begin").(time.Time)
fmt.Printf(". took: %s\n", time.Since(begin))
return ctx, nil
}
func main() {
// First, register the wrapper
sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
// Connect to the registered wrapped driver
db, _ := sql.Open("sqlite3WithHooks", ":memory:")
// Do you're stuff
db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
db.Query("SELECT id, text FROM t")
}
/*
Output should look like:
> CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
> INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
> SELECT id, text FROM t []. took: 4.653µs
*/
```
# Benchmarks
```
go test -bench=. -benchmem
BenchmarkSQLite3/Without_Hooks-4 200000 8572 ns/op 627 B/op 16 allocs/op
BenchmarkSQLite3/With_Hooks-4 200000 10231 ns/op 738 B/op 18 allocs/op
BenchmarkMySQL/Without_Hooks-4 10000 108421 ns/op 437 B/op 10 allocs/op
BenchmarkMySQL/With_Hooks-4 10000 226085 ns/op 597 B/op 13 allocs/op
BenchmarkPostgres/Without_Hooks-4 10000 125718 ns/op 649 B/op 17 allocs/op
BenchmarkPostgres/With_Hooks-4 5000 354831 ns/op 1122 B/op 27 allocs/op
PASS
ok github.com/gchaincl/sqlhooks 11.713s
```

View file

@ -0,0 +1,76 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/lib/pq"
"github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/require"
)
func init() {
hooks := &testHooks{}
hooks.noop()
sql.Register("sqlite3-benchmark", Wrap(&sqlite3.SQLiteDriver{}, hooks))
sql.Register("mysql-benchmark", Wrap(&mysql.MySQLDriver{}, hooks))
sql.Register("postgres-benchmark", Wrap(&pq.Driver{}, hooks))
}
func benchmark(b *testing.B, driver, dsn string) {
db, err := sql.Open(driver, dsn)
require.NoError(b, err)
defer db.Close()
var query = "SELECT 'hello'"
b.ResetTimer()
for i := 0; i < b.N; i++ {
rows, err := db.Query(query)
require.NoError(b, err)
require.NoError(b, rows.Close())
}
}
func BenchmarkSQLite3(b *testing.B) {
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "sqlite3", ":memory:")
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "sqlite3-benchmark", ":memory:")
})
}
func BenchmarkMySQL(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "mysql", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "mysql-benchmark", dsn)
})
}
func BenchmarkPostgres(b *testing.B) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
b.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
b.Run("Without Hooks", func(b *testing.B) {
benchmark(b, "postgres", dsn)
})
b.Run("With Hooks", func(b *testing.B) {
benchmark(b, "postgres-benchmark", dsn)
})
}

View file

@ -0,0 +1,52 @@
// package sqlhooks allows you to attach hooks to any database/sql driver.
// The purpose of sqlhooks is to provide a way to instrument your sql statements, making really easy to log queries or measure execution time without modifying your actual code.
// This example shows how to instrument sql queries in order to display the time that they consume
// package main
//
// import (
// "context"
// "database/sql"
// "fmt"
// "time"
//
// "github.com/gchaincl/sqlhooks"
// "github.com/mattn/go-sqlite3"
// )
//
// // Hooks satisfies the sqlhook.Hooks interface
// type Hooks struct {}
//
// // Before hook will print the query with it's args and return the context with the timestamp
// func (h *Hooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// fmt.Printf("> %s %q", query, args)
// return context.WithValue(ctx, "begin", time.Now()), nil
// }
//
// // After hook will get the timestamp registered on the Before hook and print the elapsed time
// func (h *Hooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
// begin := ctx.Value("begin").(time.Time)
// fmt.Printf(". took: %s\n", time.Since(begin))
// return ctx, nil
// }
//
// func main() {
// // First, register the wrapper
// sql.Register("sqlite3WithHooks", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, &Hooks{}))
//
// // Connect to the registered wrapped driver
// db, _ := sql.Open("sqlite3WithHooks", ":memory:")
//
// // Do you're stuff
// db.Exec("CREATE TABLE t (id INTEGER, text VARCHAR(16))")
// db.Exec("INSERT into t (text) VALUES(?), (?)", "foo", "bar")
// db.Query("SELECT id, text FROM t")
// }
//
// /*
// Output should look like:
// > CREATE TABLE t (id INTEGER, text VARCHAR(16)) []. took: 121.238µs
// > INSERT into t (text) VALUES(?), (?) ["foo" "bar"]. took: 36.364µs
// > SELECT id, text FROM t []. took: 4.653µs
// */
package sqlhooks

View file

@ -0,0 +1,17 @@
package loghooks
import (
"database/sql"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
)
func Example() {
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New())
sql.Register("sqlite3-logger", driver)
db, _ := sql.Open("sqlite3-logger", ":memory:")
// This query will output logs
db.Query("SELECT 1+1")
}

View file

@ -0,0 +1,31 @@
package main
import (
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/loghooks"
"github.com/mattn/go-sqlite3"
)
func main() {
sql.Register("sqlite3log", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, loghooks.New()))
db, err := sql.Open("sqlite3log", ":memory:")
if err != nil {
log.Fatal(err)
}
if _, err := db.Exec("CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.Exec(`INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.Query(`SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,30 @@
package loghooks
import (
"context"
"log"
"os"
"time"
)
type logger interface {
Printf(string, ...interface{})
}
type Hook struct {
log logger
}
func New() *Hook {
return &Hook{
log: log.New(os.Stderr, "", log.LstdFlags),
}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return context.WithValue(ctx, "started", time.Now()), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
h.log.Printf("Query: `%s`, Args: `%q`. took: %s", query, args, time.Since(ctx.Value("started").(time.Time)))
return ctx, nil
}

View file

@ -0,0 +1,39 @@
package main
import (
"context"
"database/sql"
"log"
"github.com/gchaincl/sqlhooks"
"github.com/gchaincl/sqlhooks/hooks/othooks"
"github.com/mattn/go-sqlite3"
"github.com/opentracing/opentracing-go"
)
func main() {
tracer := opentracing.GlobalTracer()
hooks := othooks.New(tracer)
sql.Register("sqlite3ot", sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, hooks))
db, err := sql.Open("sqlite3ot", ":memory:")
if err != nil {
log.Fatal(err)
}
span := tracer.StartSpan("sql")
defer span.Finish()
ctx := opentracing.ContextWithSpan(context.Background(), span)
if _, err := db.ExecContext(ctx, "CREATE TABLE users(ID int, name text)"); err != nil {
log.Fatal(err)
}
if _, err := db.ExecContext(ctx, `INSERT INTO users (id, name) VALUES(?, ?)`, 1, "gus"); err != nil {
log.Fatal(err)
}
if _, err := db.QueryContext(ctx, `SELECT id, name FROM users`); err != nil {
log.Fatal(err)
}
}

View file

@ -0,0 +1,36 @@
package othooks
import (
"context"
"github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
)
type Hook struct {
tracer opentracing.Tracer
}
func New(tracer opentracing.Tracer) *Hook {
return &Hook{tracer: tracer}
}
func (h *Hook) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
parent := opentracing.SpanFromContext(ctx)
if parent == nil {
return ctx, nil
}
span := h.tracer.StartSpan("sql", opentracing.ChildOf(parent.Context()))
ext.DBStatement.Set(span, query)
return opentracing.ContextWithSpan(ctx, span), nil
}
func (h *Hook) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
span := opentracing.SpanFromContext(ctx)
if span != nil {
defer span.Finish()
}
return ctx, nil
}

View file

@ -0,0 +1,74 @@
package othooks
import (
"context"
"database/sql"
"testing"
"github.com/gchaincl/sqlhooks"
sqlite3 "github.com/mattn/go-sqlite3"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/mocktracer"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
var (
tracer *mocktracer.MockTracer
)
func init() {
tracer = mocktracer.New()
driver := sqlhooks.Wrap(&sqlite3.SQLiteDriver{}, New(tracer))
sql.Register("ot", driver)
}
func TestSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
parent := tracer.StartSpan("parent")
ctx := opentracing.ContextWithSpan(context.Background(), parent)
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
{
rows, err := db.QueryContext(ctx, "SELECT 1+?", "1")
require.NoError(t, err)
rows.Close()
}
parent.Finish()
spans := tracer.FinishedSpans()
require.Len(t, spans, 3)
span := spans[1]
assert.Equal(t, "sql", span.OperationName)
logFields := span.Logs()[0].Fields
assert.Equal(t, "query", logFields[0].Key)
assert.Equal(t, "SELECT 1+?", logFields[0].ValueString)
assert.Equal(t, "args", logFields[1].Key)
assert.Equal(t, "[1]", logFields[1].ValueString)
assert.NotEmpty(t, span.FinishTime)
}
func TesNoSpansAreRecorded(t *testing.T) {
db, err := sql.Open("ot", ":memory:")
require.NoError(t, err)
defer db.Close()
tracer.Reset()
rows, err := db.QueryContext(context.Background(), "SELECT 1")
require.NoError(t, err)
rows.Close()
assert.Empty(t, tracer.FinishedSpans())
}

View file

@ -0,0 +1,276 @@
package sqlhooks
import (
"context"
"database/sql/driver"
"errors"
)
// Hook is the hook callback signature
type Hook func(ctx context.Context, query string, args ...interface{}) (context.Context, error)
// Hooks instances may be passed to Wrap() to define an instrumented driver
type Hooks interface {
Before(ctx context.Context, query string, args ...interface{}) (context.Context, error)
After(ctx context.Context, query string, args ...interface{}) (context.Context, error)
}
// Driver implements a database/sql/driver.Driver
type Driver struct {
driver.Driver
hooks Hooks
}
// Open opens a connection
func (drv *Driver) Open(name string) (driver.Conn, error) {
conn, err := drv.Driver.Open(name)
if err != nil {
return conn, err
}
wrapped := &Conn{conn, drv.hooks}
if isExecer(conn) {
// If conn implements an Execer interface, return a driver.Conn which
// also implements Execer
return &ExecerContext{wrapped}, nil
}
return wrapped, nil
}
// Conn implements a database/sql.driver.Conn
type Conn struct {
Conn driver.Conn
hooks Hooks
}
func (conn *Conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
var (
stmt driver.Stmt
err error
)
if c, ok := conn.Conn.(driver.ConnPrepareContext); ok {
stmt, err = c.PrepareContext(ctx, query)
} else {
stmt, err = conn.Prepare(query)
}
if err != nil {
return stmt, err
}
return &Stmt{stmt, conn.hooks, query}, nil
}
func (conn *Conn) Prepare(query string) (driver.Stmt, error) { return conn.Conn.Prepare(query) }
func (conn *Conn) Close() error { return conn.Conn.Close() }
func (conn *Conn) Begin() (driver.Tx, error) { return conn.Conn.Begin() }
func (conn *Conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
return conn.Conn.(driver.ConnBeginTx).BeginTx(ctx, opts)
}
// ExecerContext implements a database/sql.driver.ExecerContext
type ExecerContext struct {
*Conn
}
func isExecer(conn driver.Conn) bool {
switch conn.(type) {
case driver.ExecerContext:
return true
case driver.Execer:
return true
default:
return false
}
}
func (conn *ExecerContext) execContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
switch c := conn.Conn.Conn.(type) {
case driver.ExecerContext:
return c.ExecContext(ctx, query, args)
case driver.Execer:
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
return c.Exec(query, dargs)
default:
// This should not happen
return nil, errors.New("ExecerContext created for a non Execer driver.Conn")
}
}
func (conn *ExecerContext) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = conn.hooks.Before(ctx, query, list...); err != nil {
return nil, err
}
results, err := conn.execContext(ctx, query, args)
if err != nil {
return results, err
}
if ctx, err = conn.hooks.After(ctx, query, list...); err != nil {
return nil, err
}
return results, err
}
func (conn *ExecerContext) Exec(query string, args []driver.Value) (driver.Result, error) {
// We have to implement Exec since it is required in the current version of
// Go for it to run ExecContext. From Go 10 it will be optional. However,
// this code should never run since database/sql always prefers to run
// ExecContext.
return nil, errors.New("Exec was called when ExecContext was implemented")
}
// Stmt implements a database/sql/driver.Stmt
type Stmt struct {
Stmt driver.Stmt
hooks Hooks
query string
}
func (stmt *Stmt) execContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if s, ok := stmt.Stmt.(driver.StmtExecContext); ok {
return s.ExecContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Exec(values)
}
func (stmt *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
var err error
list := namedToInterface(args)
// Exec `Before` Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
results, err := stmt.execContext(ctx, args)
if err != nil {
return results, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return results, err
}
func (stmt *Stmt) queryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if s, ok := stmt.Stmt.(driver.StmtQueryContext); ok {
return s.QueryContext(ctx, args)
}
values := make([]driver.Value, len(args))
for _, arg := range args {
values[arg.Ordinal-1] = arg.Value
}
return stmt.Query(values)
}
func (stmt *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
var err error
list := namedToInterface(args)
// Exec Before Hooks
if ctx, err = stmt.hooks.Before(ctx, stmt.query, list...); err != nil {
return nil, err
}
rows, err := stmt.queryContext(ctx, args)
if err != nil {
return rows, err
}
if ctx, err = stmt.hooks.After(ctx, stmt.query, list...); err != nil {
return nil, err
}
return rows, err
}
func (stmt *Stmt) Close() error { return stmt.Stmt.Close() }
func (stmt *Stmt) NumInput() int { return stmt.Stmt.NumInput() }
func (stmt *Stmt) Exec(args []driver.Value) (driver.Result, error) { return stmt.Stmt.Exec(args) }
func (stmt *Stmt) Query(args []driver.Value) (driver.Rows, error) { return stmt.Stmt.Query(args) }
// Wrap is used to create a new instrumented driver, it takes a vendor specific driver, and a Hooks instance to produce a new driver instance.
// It's usually used inside a sql.Register() statement
func Wrap(driver driver.Driver, hooks Hooks) driver.Driver {
return &Driver{driver, hooks}
}
func namedToInterface(args []driver.NamedValue) []interface{} {
list := make([]interface{}, len(args))
for i, a := range args {
list[i] = a.Value
}
return list
}
// namedValueToValue copied from database/sql
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
return nil, errors.New("sql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
/*
type hooks struct {
}
func (h *hooks) Before(ctx context.Context, query string, args ...interface{}) error {
log.Printf("before> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func (h *hooks) After(ctx context.Context, query string, args ...interface{}) error {
log.Printf("after> ctx = %+v, q=%s, args = %+v\n", ctx, query, args)
return nil
}
func main() {
sql.Register("sqlite3-proxy", Wrap(&sqlite3.SQLiteDriver{}, &hooks{}))
db, err := sql.Open("sqlite3-proxy", ":memory:")
if err != nil {
log.Fatalln(err)
}
if _, ok := driver.Stmt(&Stmt{}).(driver.StmtExecContext); !ok {
panic("NOPE")
}
if _, err := db.Exec("CREATE table users(id int)"); err != nil {
log.Printf("|err| = %+v\n", err)
}
if _, err := db.QueryContext(context.Background(), "SELECT * FROM users WHERE id = ?", 1); err != nil {
log.Printf("err = %+v\n", err)
}
}
*/

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/go-sql-driver/mysql"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpMySQL(t *testing.T, dsn string) {
db, err := sql.Open("mysql", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestMySQL(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_MYSQL_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_MYSQL_DSN not set")
}
setUpMySQL(t, dsn)
s := newSuite(t, &mysql.MySQLDriver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,56 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"github.com/lib/pq"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUpPostgres(t *testing.T, dsn string) {
db, err := sql.Open("postgres", dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
defer db.Close()
_, err = db.Exec("CREATE table IF NOT EXISTS users(id int, name text)")
require.NoError(t, err)
}
func TestPostgres(t *testing.T) {
dsn := os.Getenv("SQLHOOKS_POSTGRES_DSN")
if dsn == "" {
t.Skipf("SQLHOOKS_POSTGRES_DSN not set")
}
setUpPostgres(t, dsn)
s := newSuite(t, &pq.Driver{}, dsn)
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = $1", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = $1 AND name = $2", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES($1, $2)")
require.NoError(t, err)
for i := range [5]struct{}{} {
_, err := stmt.Exec(i, "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,54 @@
package sqlhooks
import (
"database/sql"
"os"
"testing"
"time"
sqlite3 "github.com/mattn/go-sqlite3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func setUp(t *testing.T) func() {
dbName := "sqlite3test.db"
db, err := sql.Open("sqlite3", dbName)
require.NoError(t, err)
defer db.Close()
_, err = db.Exec("CREATE table users(id int, name text)")
require.NoError(t, err)
return func() { os.Remove(dbName) }
}
func TestSQLite3(t *testing.T) {
defer setUp(t)()
s := newSuite(t, &sqlite3.SQLiteDriver{}, "sqlite3test.db")
s.TestHooksExecution(t, "SELECT * FROM users WHERE id = ?", 1)
s.TestHooksArguments(t, "SELECT * FROM users WHERE id = ? AND name = ?", int64(1), "Gus")
s.TestHooksErrors(t, "SELECT 1+1")
t.Run("DBWorks", func(t *testing.T) {
s.hooks.noop()
if _, err := s.db.Exec("DELETE FROM users"); err != nil {
t.Fatal(err)
}
stmt, err := s.db.Prepare("INSERT INTO users (id, name) VALUES(?, ?)")
require.NoError(t, err)
for range [5]struct{}{} {
_, err := stmt.Exec(time.Now().UnixNano(), "gus")
require.NoError(t, err)
}
var count int
require.NoError(t,
s.db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count),
)
assert.Equal(t, 5, count)
})
}

View file

@ -0,0 +1,167 @@
package sqlhooks
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testHooks struct {
before Hook
after Hook
}
func (h *testHooks) noop() {
noop := func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, nil
}
h.before, h.after = noop, noop
}
func (h *testHooks) Before(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.before(ctx, query, args...)
}
func (h *testHooks) After(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return h.after(ctx, query, args...)
}
type suite struct {
db *sql.DB
hooks *testHooks
}
func newSuite(t *testing.T, driver driver.Driver, dsn string) *suite {
hooks := &testHooks{}
driverName := fmt.Sprintf("sqlhooks-%s", time.Now().String())
sql.Register(driverName, Wrap(driver, hooks))
db, err := sql.Open(driverName, dsn)
require.NoError(t, err)
require.NoError(t, db.Ping())
return &suite{db, hooks}
}
func (s *suite) TestHooksExecution(t *testing.T, query string, args ...interface{}) {
var before, after bool
s.hooks.before = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
before = true
return ctx, nil
}
s.hooks.after = func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
after = true
return ctx, nil
}
t.Run("Query", func(t *testing.T) {
before, after = false, false
_, err := s.db.Query(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("QueryContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.QueryContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Exec", func(t *testing.T) {
before, after = false, false
_, err := s.db.Exec(query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("ExecContext", func(t *testing.T) {
before, after = false, false
_, err := s.db.ExecContext(context.Background(), query, args...)
require.NoError(t, err)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
t.Run("Statements", func(t *testing.T) {
before, after = false, false
stmt, err := s.db.Prepare(query)
require.NoError(t, err)
// Hooks just run when the stmt is executed (Query or Exec)
assert.False(t, before, "Before Hook run before execution: "+query)
assert.False(t, after, "After Hook run before execution: "+query)
stmt.Query(args...)
assert.True(t, before, "Before Hook did not run for query: "+query)
assert.True(t, after, "After Hook did not run for query: "+query)
})
}
func (s *suite) testHooksArguments(t *testing.T, query string, args ...interface{}) {
hook := func(ctx context.Context, q string, a ...interface{}) (context.Context, error) {
assert.Equal(t, query, q)
assert.Equal(t, args, a)
assert.Equal(t, "val", ctx.Value("key").(string))
return ctx, nil
}
s.hooks.before = hook
s.hooks.after = hook
ctx := context.WithValue(context.Background(), "key", "val")
{
_, err := s.db.QueryContext(ctx, query, args...)
require.NoError(t, err)
}
{
_, err := s.db.ExecContext(ctx, query, args...)
require.NoError(t, err)
}
}
func (s *suite) TestHooksArguments(t *testing.T, query string, args ...interface{}) {
t.Run("TestHooksArguments", func(t *testing.T) { s.testHooksArguments(t, query, args...) })
}
func (s *suite) testHooksErrors(t *testing.T, query string) {
boom := errors.New("boom")
s.hooks.before = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
return ctx, boom
}
s.hooks.after = func(ctx context.Context, query string, args ...interface{}) (context.Context, error) {
assert.False(t, true, "this should not run")
return ctx, nil
}
_, err := s.db.Query(query)
assert.Equal(t, boom, err)
}
func (s *suite) TestHooksErrors(t *testing.T, query string) {
t.Run("TestHooksErrors", func(t *testing.T) { s.testHooksErrors(t, query) })
}
func TestNamedValueToValue(t *testing.T) {
named := []driver.NamedValue{
{Ordinal: 1, Value: "foo"},
{Ordinal: 2, Value: 42},
}
want := []driver.Value{"foo", 42}
dargs, err := namedValueToValue(named)
require.NoError(t, err)
assert.Equal(t, want, dargs)
}

View file

@ -45,15 +45,22 @@ type Message struct {
Key []byte
Value []byte
Timestamp time.Time
Headers []sarama.RecordHeader
}
func (m *Message) consumerMessage(topic string) *sarama.ConsumerMessage {
var headers []*sarama.RecordHeader
for _, header := range m.Headers {
headers = append(headers, &header)
}
return &sarama.ConsumerMessage{
Topic: topic,
Offset: m.Offset,
Key: m.Key,
Value: m.Value,
Timestamp: m.Timestamp,
Headers: headers,
}
}
@ -321,6 +328,8 @@ func (t *topic) send(now time.Time, pmsgs []*sarama.ProducerMessage) error {
}
pmsgs[i].Timestamp = now
msgs[i].Timestamp = now
msgs[i].Headers = pmsgs[i].Headers
}
// Take the lock before assigning the offsets.
t.mutex.Lock()

View file

@ -4,6 +4,11 @@ import (
"database/sql"
"sync"
"time"
"github.com/lib/pq"
"github.com/pkg/errors"
sarama "gopkg.in/Shopify/sarama.v1"
)
const postgresqlSchema = `
@ -21,6 +26,7 @@ CREATE TABLE IF NOT EXISTS naffka_messages (
message_key BYTEA NOT NULL,
message_value BYTEA NOT NULL,
message_timestamp_ns BIGINT NOT NULL,
message_headers BYTEA[] NOT NULL, -- RecordHeaders stored in alternating key value pairs
UNIQUE (topic_nid, message_offset)
);
`
@ -37,11 +43,11 @@ const selectTopicsSQL = "" +
"SELECT topic_name, topic_nid FROM naffka_topics"
const insertMessageSQL = "" +
"INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns)" +
" VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO naffka_messages (topic_nid, message_offset, message_key, message_value, message_timestamp_ns, message_headers)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectMessagesSQL = "" +
"SELECT message_offset, message_key, message_value, message_timestamp_ns" +
"SELECT message_offset, message_key, message_value, message_timestamp_ns, message_headers" +
" FROM naffka_messages WHERE topic_nid = $1 AND $2 <= message_offset AND message_offset < $3" +
" ORDER BY message_offset ASC"
@ -104,7 +110,13 @@ func (p *postgresqlDatabase) StoreMessages(topic string, messages []Message) err
return err
}
for _, m := range messages {
_, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano())
// We store the headers as alternating key value pairs
var headers [][]byte
for _, h := range m.Headers {
headers = append(headers, h.Key, h.Value)
}
_, err = s.Exec(topicNID, m.Offset, m.Key, m.Value, m.Timestamp.UnixNano(), pq.Array(headers))
if err != nil {
return err
}
@ -130,15 +142,36 @@ func (p *postgresqlDatabase) FetchMessages(topic string, startOffset, endOffset
key []byte
value []byte
timestampNano int64
headerlists pq.ByteaArray
)
if err = rows.Scan(&offset, &key, &value, &timestampNano); err != nil {
if err = rows.Scan(&offset, &key, &value, &timestampNano, &headerlists); err != nil {
return
}
// We store the headers as alternating key value pairs, so check that
// there are an even number
if len(headerlists)%2 != 0 {
err = errors.Errorf(
"message_headers has non even number of entries for topic %s offset %d",
topic, offset,
)
return
}
var headers []sarama.RecordHeader
for i := 0; i < len(headerlists); i += 2 {
headers = append(headers, sarama.RecordHeader{
Key: headerlists[i],
Value: headerlists[i+1],
})
}
messages = append(messages, Message{
Offset: offset,
Key: key,
Value: value,
Timestamp: time.Unix(0, timestampNano),
Headers: headers,
})
}
return