Add roomserver tests (3/4) (#2447)

* Add Room Aliases tests

* Add Rooms table test

* Move StateKeyTuplerSorter to the types package

* Add StateBlock tests
Some optimizations

* Add State Snapshot tests
Some optimization

* Return []int64 and convert to pq.Int64Array for postgres

* Move []types.EventNID back to rows.Next()

* Update tests, rename SelectRoomIDs
This commit is contained in:
Till 2022-05-16 19:33:16 +02:00 committed by GitHub
parent 6af35385ba
commit 05607d6b87
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 570 additions and 313 deletions

View file

@ -264,11 +264,11 @@ func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
tuples := types.StateKeyTupleSorter(stateKeyTuples)
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
eventTypeNIDArray, eventStateKeyNIDArray := tuples.TypesAndStateKeysAsArrays()
stmt := sqlutil.TxStmt(txn, s.bulkSelectStateEventByNIDStmt)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), eventTypeNIDArray, eventStateKeyNIDArray)
rows, err := stmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs), pq.Int64Array(eventTypeNIDArray), pq.Int64Array(eventStateKeyNIDArray))
if err != nil {
return nil, err
}

View file

@ -61,12 +61,12 @@ type roomAliasesStatements struct {
deleteRoomAliasStmt *sql.Stmt
}
func createRoomAliasesTable(db *sql.DB) error {
func CreateRoomAliasesTable(db *sql.DB) error {
_, err := db.Exec(roomAliasesSchema)
return err
}
func prepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
func PrepareRoomAliasesTable(db *sql.DB) (tables.RoomAliases, error) {
s := &roomAliasesStatements{}
return s, sqlutil.StatementList{
@ -108,8 +108,8 @@ func (s *roomAliasesStatements) SelectAliasesFromRoomID(
defer internal.CloseAndLogIfError(ctx, rows, "selectAliasesFromRoomID: rows.close() failed")
var aliases []string
var alias string
for rows.Next() {
var alias string
if err = rows.Scan(&alias); err != nil {
return nil, err
}

View file

@ -95,12 +95,12 @@ type roomStatements struct {
bulkSelectRoomNIDsStmt *sql.Stmt
}
func createRoomsTable(db *sql.DB) error {
func CreateRoomsTable(db *sql.DB) error {
_, err := db.Exec(roomsSchema)
return err
}
func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
func PrepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
s := &roomStatements{}
return s, sqlutil.StatementList{
@ -117,7 +117,7 @@ func prepareRoomsTable(db *sql.DB) (tables.Rooms, error) {
}.Prepare(db)
}
func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]string, error) {
func (s *roomStatements) SelectRoomIDsWithEvents(ctx context.Context, txn *sql.Tx) ([]string, error) {
stmt := sqlutil.TxStmt(txn, s.selectRoomIDsStmt)
rows, err := stmt.QueryContext(ctx)
if err != nil {
@ -125,8 +125,8 @@ func (s *roomStatements) SelectRoomIDs(ctx context.Context, txn *sql.Tx) ([]stri
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -231,9 +231,9 @@ func (s *roomStatements) SelectRoomVersionsForRoomNIDs(
}
defer internal.CloseAndLogIfError(ctx, rows, "selectRoomVersionsForRoomNIDsStmt: rows.close() failed")
result := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
for rows.Next() {
var roomNID types.RoomNID
var roomVersion gomatrixserverlib.RoomVersion
if err = rows.Scan(&roomNID, &roomVersion); err != nil {
return nil, err
}
@ -254,8 +254,8 @@ func (s *roomStatements) BulkSelectRoomIDs(ctx context.Context, txn *sql.Tx, roo
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomIDsStmt: rows.close() failed")
var roomIDs []string
var roomID string
for rows.Next() {
var roomID string
if err = rows.Scan(&roomID); err != nil {
return nil, err
}
@ -276,8 +276,8 @@ func (s *roomStatements) BulkSelectRoomNIDs(ctx context.Context, txn *sql.Tx, ro
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectRoomNIDsStmt: rows.close() failed")
var roomNIDs []types.RoomNID
var roomNID types.RoomNID
for rows.Next() {
var roomNID types.RoomNID
if err = rows.Scan(&roomNID); err != nil {
return nil, err
}

View file

@ -19,7 +19,6 @@ import (
"context"
"database/sql"
"fmt"
"sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
@ -71,12 +70,12 @@ type stateBlockStatements struct {
bulkSelectStateBlockEntriesStmt *sql.Stmt
}
func createStateBlockTable(db *sql.DB) error {
func CreateStateBlockTable(db *sql.DB) error {
_, err := db.Exec(stateDataSchema)
return err
}
func prepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
func PrepareStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
return s, sqlutil.StatementList{
@ -90,9 +89,9 @@ func (s *stateBlockStatements) BulkInsertStateData(
entries types.StateEntries,
) (id types.StateBlockNID, err error) {
entries = entries[:util.SortAndUnique(entries)]
var nids types.EventNIDs
for _, e := range entries {
nids = append(nids, e.EventNID)
nids := make(types.EventNIDs, entries.Len())
for i := range entries {
nids[i] = entries[i].EventNID
}
stmt := sqlutil.TxStmt(txn, s.insertStateDataStmt)
err = stmt.QueryRowContext(
@ -113,15 +112,15 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries(
results := make([][]types.EventNID, len(stateBlockNIDs))
i := 0
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
for ; rows.Next(); i++ {
var stateBlockNID types.StateBlockNID
var result pq.Int64Array
if err = rows.Scan(&stateBlockNID, &result); err != nil {
return nil, err
}
r := []types.EventNID{}
for _, e := range result {
r = append(r, types.EventNID(e))
r := make([]types.EventNID, len(result))
for x := range result {
r[x] = types.EventNID(result[x])
}
results[i] = r
}
@ -141,35 +140,3 @@ func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
}
return pq.Int64Array(nids)
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -77,12 +77,12 @@ type stateSnapshotStatements struct {
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func createStateSnapshotTable(db *sql.DB) error {
func CreateStateSnapshotTable(db *sql.DB) error {
_, err := db.Exec(stateSnapshotSchema)
return err
}
func prepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
func PrepareStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
return s, sqlutil.StatementList{
@ -95,12 +95,10 @@ func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, nids types.StateBlockNIDs,
) (stateNID types.StateSnapshotNID, err error) {
nids = nids[:util.SortAndUnique(nids)]
var id int64
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&id)
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, nids.Hash(), int64(roomNID), stateBlockNIDsAsArray(nids)).Scan(&stateNID)
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(id)
return
}
@ -119,9 +117,9 @@ func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
var stateBlockNIDs pq.Int64Array
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}

View file

@ -80,19 +80,19 @@ func (d *Database) create(db *sql.DB) error {
if err := CreateEventsTable(db); err != nil {
return err
}
if err := createRoomsTable(db); err != nil {
if err := CreateRoomsTable(db); err != nil {
return err
}
if err := createStateBlockTable(db); err != nil {
if err := CreateStateBlockTable(db); err != nil {
return err
}
if err := createStateSnapshotTable(db); err != nil {
if err := CreateStateSnapshotTable(db); err != nil {
return err
}
if err := CreatePrevEventsTable(db); err != nil {
return err
}
if err := createRoomAliasesTable(db); err != nil {
if err := CreateRoomAliasesTable(db); err != nil {
return err
}
if err := CreateInvitesTable(db); err != nil {
@ -128,15 +128,15 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
rooms, err := prepareRoomsTable(db)
rooms, err := PrepareRoomsTable(db)
if err != nil {
return err
}
stateBlock, err := prepareStateBlockTable(db)
stateBlock, err := PrepareStateBlockTable(db)
if err != nil {
return err
}
stateSnapshot, err := prepareStateSnapshotTable(db)
stateSnapshot, err := PrepareStateSnapshotTable(db)
if err != nil {
return err
}
@ -144,7 +144,7 @@ func (d *Database) prepare(db *sql.DB, writer sqlutil.Writer, cache caching.Room
if err != nil {
return err
}
roomAliases, err := prepareRoomAliasesTable(db)
roomAliases, err := PrepareRoomAliasesTable(db)
if err != nil {
return err
}