Don't use more than 999 variables in SQLite querys. (#1425)

* Don't use more than 999 variables in SQLite querys.

Solve this problem in a more general and reusable way.
Also fix #1369
Add some unit tests.

Signed-off-by: Henrik Sölver <henrik.solver@gmail.com>

* Don't rely on testify for basic assertions

* Readability improvements and linting

Co-authored-by: Henrik Sölver <henrik.solver@gmail.com>
This commit is contained in:
Kegsay 2020-09-14 16:39:38 +01:00 committed by GitHub
parent 913020e4b7
commit 8dc9506210
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 255 additions and 33 deletions

1
go.mod
View file

@ -1,6 +1,7 @@
module github.com/matrix-org/dendrite module github.com/matrix-org/dendrite
require ( require (
github.com/DATA-DOG/go-sqlmock v1.5.0
github.com/Shopify/sarama v1.27.0 github.com/Shopify/sarama v1.27.0
github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect github.com/codahale/hdrhistogram v0.0.0-20161010025455-3a0bb77429bd // indirect
github.com/gologme/log v1.2.0 github.com/gologme/log v1.2.0

2
go.sum
View file

@ -13,6 +13,8 @@ github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0 h1:p3puK8Sl2xK+2Fnn
github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI= github.com/Arceliar/phony v0.0.0-20191006174943-d0c68492aca0/go.mod h1:6Lkn+/zJilRMsKmbmG1RPoamiArC6HS73xbwRyp3UyI=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ= github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.0/go.mod h1:f/Ixk793poVmq4qj/V1dPUg2JEAKC73Q5eFN3EC/SaM=
github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y= github.com/Kubuxu/go-os-helper v0.0.1/go.mod h1:N8B+I7vPCT80IcP58r50u4+gEEcsZETFUpAzWW2ep1Y=
github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU= github.com/OneOfOne/xxhash v1.2.2/go.mod h1:HSdplMjZKSmBqAxg5vPj2TmRDmfkzw+cTzAElWljhcU=
github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc=

View file

@ -15,10 +15,14 @@
package sqlutil package sqlutil
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"runtime" "runtime"
"strings"
"github.com/matrix-org/util"
) )
// ErrUserExists is returned if a username already exists in the database. // ErrUserExists is returned if a username already exists in the database.
@ -107,3 +111,44 @@ func SQLiteDriverName() string {
} }
return "sqlite3" return "sqlite3"
} }
func minOfInts(a, b int) int {
if a <= b {
return a
}
return b
}
// QueryProvider defines the interface for querys used by RunLimitedVariablesQuery.
type QueryProvider interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
// SQLite3MaxVariables is the default maximum number of host parameters in a single SQL statement
// SQLlite can handle. See https://www.sqlite.org/limits.html for more information.
const SQLite3MaxVariables = 999
// RunLimitedVariablesQuery split up a query with more variables than the used database can handle in multiple queries.
func RunLimitedVariablesQuery(ctx context.Context, query string, qp QueryProvider, variables []interface{}, limit uint, rowHandler func(*sql.Rows) error) error {
var start int
for start < len(variables) {
n := minOfInts(len(variables)-start, int(limit))
nextQuery := strings.Replace(query, "($1)", QueryVariadic(n), 1)
rows, err := qp.QueryContext(ctx, nextQuery, variables[start:start+n]...)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("QueryContext returned an error")
return err
}
err = rowHandler(rows)
if closeErr := rows.Close(); closeErr != nil {
util.GetLogger(ctx).WithError(closeErr).Error("RunLimitedVariablesQuery: failed to close rows")
return err
}
if err != nil {
util.GetLogger(ctx).WithError(err).Error("RunLimitedVariablesQuery: rowHandler returned error")
return err
}
start = start + n
}
return nil
}

View file

@ -0,0 +1,173 @@
package sqlutil
import (
"context"
"database/sql"
"reflect"
"testing"
sqlmock "github.com/DATA-DOG/go-sqlmock"
)
func TestShouldReturnCorrectAmountOfResulstIfFewerVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 3 long")
}
}
func TestShouldReturnCorrectAmountOfResulstIfEqualVariablesAsLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 4 long")
}
}
func TestShouldReturnCorrectAmountOfResultsIfMoreVariablesThanLimit(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
r1 := mock.NewRows([]string{"id"}).
AddRow(1).
AddRow(2).
AddRow(3).
AddRow(4)
r2 := mock.NewRows([]string{"id"}).
AddRow(5)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3, \$4\)`).WillReturnRows(r1)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1\)`).WillReturnRows(r2)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{1, 2, 3, 4, 5}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]int, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id int
err = rows.Scan(&id)
assertNoError(t, err, "rows.Scan returned an error")
result = append(result, id)
}
return nil
})
assertNoError(t, err, "Call returned an error")
if len(result) != len(v) {
t.Fatalf("Result should be 5 long")
}
if !reflect.DeepEqual(v, result) {
t.Fatalf("Result is not as expected: got %v want %v", v, result)
}
}
func TestShouldReturnErrorIfRowsScanReturnsError(t *testing.T) {
db, mock, err := sqlmock.New()
assertNoError(t, err, "Failed to make DB")
limit := uint(4)
// adding a string ID should result in rows.Scan returning an error
r := mock.NewRows([]string{"id"}).
AddRow("hej").
AddRow(2).
AddRow(3)
mock.ExpectQuery(`SELECT id WHERE id IN \(\$1, \$2, \$3\)`).WillReturnRows(r)
// nolint:goconst
q := "SELECT id WHERE id IN ($1)"
v := []int{-1, -2, 3}
iKeyIDs := make([]interface{}, len(v))
for i, d := range v {
iKeyIDs[i] = d
}
ctx := context.Background()
var result = make([]uint, 0)
err = RunLimitedVariablesQuery(ctx, q, db, iKeyIDs, limit, func(rows *sql.Rows) error {
for rows.Next() {
var id uint
err = rows.Scan(&id)
if err != nil {
return err
}
result = append(result, id)
}
return nil
})
if err == nil {
t.Fatalf("Call did not return an error")
}
}
func assertNoError(t *testing.T, err error, msg string) {
t.Helper()
if err == nil {
return
}
t.Fatalf(msg)
}

View file

@ -18,9 +18,8 @@ package sqlite3
import ( import (
"context" "context"
"database/sql" "database/sql"
"strings" "fmt"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -88,41 +87,36 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ctx context.Context, ctx context.Context,
requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp, requests map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.Timestamp,
) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) { ) (map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, error) {
var nameAndKeyIDs []string nameAndKeyIDs := make([]string, 0, len(requests))
for request := range requests { for request := range requests {
nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request)) nameAndKeyIDs = append(nameAndKeyIDs, nameAndKeyID(request))
} }
results := make(map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult, len(requests))
query := strings.Replace(bulkSelectServerKeysSQL, "($1)", sqlutil.QueryVariadic(len(nameAndKeyIDs)), 1)
iKeyIDs := make([]interface{}, len(nameAndKeyIDs)) iKeyIDs := make([]interface{}, len(nameAndKeyIDs))
for i, v := range nameAndKeyIDs { for i, v := range nameAndKeyIDs {
iKeyIDs[i] = v iKeyIDs[i] = v
} }
rows, err := s.db.QueryContext(ctx, query, iKeyIDs...) err := sqlutil.RunLimitedVariablesQuery(
if err != nil { ctx, bulkSelectServerKeysSQL, s.db, iKeyIDs, sqlutil.SQLite3MaxVariables,
return nil, err func(rows *sql.Rows) error {
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectServerKeys: rows.close() failed")
results := map[gomatrixserverlib.PublicKeyLookupRequest]gomatrixserverlib.PublicKeyLookupResult{}
for rows.Next() { for rows.Next() {
var serverName string var serverName string
var keyID string var keyID string
var key string var key string
var validUntilTS int64 var validUntilTS int64
var expiredTS int64 var expiredTS int64
if err = rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil { if err := rows.Scan(&serverName, &keyID, &validUntilTS, &expiredTS, &key); err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
r := gomatrixserverlib.PublicKeyLookupRequest{ r := gomatrixserverlib.PublicKeyLookupRequest{
ServerName: gomatrixserverlib.ServerName(serverName), ServerName: gomatrixserverlib.ServerName(serverName),
KeyID: gomatrixserverlib.KeyID(keyID), KeyID: gomatrixserverlib.KeyID(keyID),
} }
vk := gomatrixserverlib.VerifyKey{} vk := gomatrixserverlib.VerifyKey{}
err = vk.Key.Decode(key) err := vk.Key.Decode(key)
if err != nil { if err != nil {
return nil, err return fmt.Errorf("bulkSelectServerKeys: %v", err)
} }
results[r] = gomatrixserverlib.PublicKeyLookupResult{ results[r] = gomatrixserverlib.PublicKeyLookupResult{
VerifyKey: vk, VerifyKey: vk,
@ -130,6 +124,13 @@ func (s *serverKeyStatements) bulkSelectServerKeys(
ExpiredTS: gomatrixserverlib.Timestamp(expiredTS), ExpiredTS: gomatrixserverlib.Timestamp(expiredTS),
} }
} }
return nil
},
)
if err != nil {
return nil, err
}
return results, nil return results, nil
} }