Correctly use provided filters (#2339)

* Apply filters correctly

* Fix issues; Use prepareWithFilters

* Update gmsl & tests

* go.mod..

* PR comments
This commit is contained in:
Till 2022-04-11 09:05:23 +02:00 committed by GitHub
parent b4b2fbc36b
commit 69f2ff7c82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 109 additions and 67 deletions

View file

@ -41,10 +41,10 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"
// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
" WHERE user_id = $1 AND id > $2 AND id <= $3"
const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
@ -94,18 +94,25 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
userID string,
r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, filter.Limit, FilterOrderAsc,
)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() {
var dataType string
var roomID string
@ -114,31 +121,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return
}
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}
if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType)
} else {
data[roomID] = []string{dataType}
}
entries++
if entries >= accountDataFilterPart.Limit {
break
}
}
return data, nil

View file

@ -25,32 +25,48 @@ const (
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
}
}
if count := len(excludeEventIDs); count > 0 {