Correctly use provided filters (#2339)

* Apply filters correctly

* Fix issues; Use prepareWithFilters

* Update gmsl & tests

* go.mod..

* PR comments
This commit is contained in:
Till 2022-04-11 09:05:23 +02:00 committed by GitHub
parent b4b2fbc36b
commit 69f2ff7c82
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 109 additions and 67 deletions

View file

@ -233,9 +233,10 @@ func (s *currentRoomStateStatements) SelectCurrentState(
excludeEventIDs []string,
) ([]*gomatrixserverlib.HeaderedEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectCurrentStateStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(ctx, roomID,
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,

View file

@ -16,21 +16,45 @@ package postgres
import (
"strings"
"github.com/matrix-org/gomatrixserverlib"
)
// filterConvertWildcardToSQL converts wildcards as defined in
// https://matrix.org/docs/spec/client_server/r0.3.0.html#post-matrix-client-r0-user-userid-filter
// to SQL wildcards that can be used with LIKE()
func filterConvertTypeWildcardToSQL(values []string) []string {
func filterConvertTypeWildcardToSQL(values *[]string) []string {
if values == nil {
// Return nil instead of []string{} so IS NULL can work correctly when
// the return value is passed into SQL queries
return nil
}
ret := make([]string, len(values))
for i := range values {
ret[i] = strings.Replace(values[i], "*", "%", -1)
v := *values
ret := make([]string, len(v))
for i := range v {
ret[i] = strings.Replace(v[i], "*", "%", -1)
}
return ret
}
// TODO: Replace when Dendrite uses Go 1.18
func getSendersRoomEventFilter(filter *gomatrixserverlib.RoomEventFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}
func getSendersStateFilterFilter(filter *gomatrixserverlib.StateFilter) (senders []string, notSenders []string) {
if filter.Senders != nil {
senders = *filter.Senders
}
if filter.NotSenders != nil {
notSenders = *filter.NotSenders
}
return senders, notSenders
}

View file

@ -204,11 +204,11 @@ func (s *outputRoomEventsStatements) SelectStateInRange(
stateFilter *gomatrixserverlib.StateFilter, roomIDs []string,
) (map[string]map[string]bool, map[string]types.StreamEvent, error) {
stmt := sqlutil.TxStmt(txn, s.selectStateInRangeStmt)
senders, notSenders := getSendersStateFilterFilter(stateFilter)
rows, err := stmt.QueryContext(
ctx, r.Low(), r.High(), pq.StringArray(roomIDs),
pq.StringArray(stateFilter.Senders),
pq.StringArray(stateFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(stateFilter.NotTypes)),
stateFilter.ContainsURL,
@ -353,10 +353,11 @@ func (s *outputRoomEventsStatements) SelectRecentEvents(
} else {
stmt = sqlutil.TxStmt(txn, s.selectRecentEventsStmt)
}
senders, notSenders := getSendersRoomEventFilter(eventFilter)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit+1,
@ -398,11 +399,12 @@ func (s *outputRoomEventsStatements) SelectEarlyEvents(
ctx context.Context, txn *sql.Tx,
roomID string, r types.Range, eventFilter *gomatrixserverlib.RoomEventFilter,
) ([]types.StreamEvent, error) {
senders, notSenders := getSendersRoomEventFilter(eventFilter)
stmt := sqlutil.TxStmt(txn, s.selectEarlyEventsStmt)
rows, err := stmt.QueryContext(
ctx, roomID, r.Low(), r.High(),
pq.StringArray(eventFilter.Senders),
pq.StringArray(eventFilter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(eventFilter.NotTypes)),
eventFilter.Limit,
@ -480,10 +482,11 @@ func (s *outputRoomEventsStatements) SelectContextEvent(ctx context.Context, txn
func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextBeforeEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)
@ -512,10 +515,11 @@ func (s *outputRoomEventsStatements) SelectContextBeforeEvent(
func (s *outputRoomEventsStatements) SelectContextAfterEvent(
ctx context.Context, txn *sql.Tx, id int, roomID string, filter *gomatrixserverlib.RoomEventFilter,
) (lastID int, evts []*gomatrixserverlib.HeaderedEvent, err error) {
senders, notSenders := getSendersRoomEventFilter(filter)
rows, err := sqlutil.TxStmt(txn, s.selectContextAfterEventStmt).QueryContext(
ctx, roomID, id, filter.Limit,
pq.StringArray(filter.Senders),
pq.StringArray(filter.NotSenders),
pq.StringArray(senders),
pq.StringArray(notSenders),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.Types)),
pq.StringArray(filterConvertTypeWildcardToSQL(filter.NotTypes)),
)

View file

@ -41,10 +41,10 @@ const insertAccountDataSQL = "" +
" ON CONFLICT (user_id, room_id, type) DO UPDATE" +
" SET id = $5"
// further parameters are added by prepareWithFilters
const selectAccountDataInRangeSQL = "" +
"SELECT room_id, type FROM syncapi_account_data_type" +
" WHERE user_id = $1 AND id > $2 AND id <= $3" +
" ORDER BY id ASC"
" WHERE user_id = $1 AND id > $2 AND id <= $3"
const selectMaxAccountDataIDSQL = "" +
"SELECT MAX(id) FROM syncapi_account_data_type"
@ -94,18 +94,25 @@ func (s *accountDataStatements) SelectAccountDataInRange(
ctx context.Context,
userID string,
r types.Range,
accountDataFilterPart *gomatrixserverlib.EventFilter,
filter *gomatrixserverlib.EventFilter,
) (data map[string][]string, err error) {
data = make(map[string][]string)
stmt, params, err := prepareWithFilters(
s.db, nil, selectAccountDataInRangeSQL,
[]interface{}{
userID, r.Low(), r.High(),
},
filter.Senders, filter.NotSenders,
filter.Types, filter.NotTypes,
[]string{}, filter.Limit, FilterOrderAsc,
)
rows, err := s.selectAccountDataInRangeStmt.QueryContext(ctx, userID, r.Low(), r.High())
rows, err := stmt.QueryContext(ctx, params...)
if err != nil {
return
}
defer internal.CloseAndLogIfError(ctx, rows, "selectAccountDataInRange: rows.close() failed")
var entries int
for rows.Next() {
var dataType string
var roomID string
@ -114,31 +121,11 @@ func (s *accountDataStatements) SelectAccountDataInRange(
return
}
// check if we should add this by looking at the filter.
// It would be nice if we could do this in SQL-land, but the mix of variadic
// and positional parameters makes the query annoyingly hard to do, it's easier
// and clearer to do it in Go-land. If there are no filters for [not]types then
// this gets skipped.
for _, includeType := range accountDataFilterPart.Types {
if includeType != dataType { // TODO: wildcard support
continue
}
}
for _, excludeType := range accountDataFilterPart.NotTypes {
if excludeType == dataType { // TODO: wildcard support
continue
}
}
if len(data[roomID]) > 0 {
data[roomID] = append(data[roomID], dataType)
} else {
data[roomID] = []string{dataType}
}
entries++
if entries >= accountDataFilterPart.Limit {
break
}
}
return data, nil

View file

@ -25,32 +25,48 @@ const (
// parts.
func prepareWithFilters(
db *sql.DB, txn *sql.Tx, query string, params []interface{},
senders, notsenders, types, nottypes []string, excludeEventIDs []string,
senders, notsenders, types, nottypes *[]string, excludeEventIDs []string,
limit int, order FilterOrder,
) (*sql.Stmt, []interface{}, error) {
offset := len(params)
if count := len(senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range senders {
params, offset = append(params, v), offset+1
if senders != nil {
if count := len(*senders); count > 0 {
query += " AND sender IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *senders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender = ""`
}
}
if count := len(notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range notsenders {
params, offset = append(params, v), offset+1
if notsenders != nil {
if count := len(*notsenders); count > 0 {
query += " AND sender NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *notsenders {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND sender NOT = ""`
}
}
if count := len(types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range types {
params, offset = append(params, v), offset+1
if types != nil {
if count := len(*types); count > 0 {
query += " AND type IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *types {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type = ""`
}
}
if count := len(nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range nottypes {
params, offset = append(params, v), offset+1
if nottypes != nil {
if count := len(*nottypes); count > 0 {
query += " AND type NOT IN " + sqlutil.QueryVariadicOffset(count, offset)
for _, v := range *nottypes {
params, offset = append(params, v), offset+1
}
} else {
query += ` AND type NOT = ""`
}
}
if count := len(excludeEventIDs); count > 0 {