From 6900e0f495e5132d7bc7e98c30990fd1b5574e5f Mon Sep 17 00:00:00 2001 From: Neil Alexander Date: Thu, 15 Apr 2021 15:51:15 +0100 Subject: [PATCH] Deduplicate state block contents --- roomserver/state/state.go | 5 + roomserver/storage/postgres/events_table.go | 49 ++++ .../storage/postgres/state_block_table.go | 194 +++------------ .../storage/postgres/state_snapshot_table.go | 9 +- roomserver/storage/shared/storage.go | 59 ++++- roomserver/storage/sqlite3/events_table.go | 56 +++++ .../storage/sqlite3/state_block_table.go | 223 ++++-------------- .../storage/sqlite3/state_snapshot_table.go | 4 +- roomserver/storage/tables/interface.go | 9 +- roomserver/types/types.go | 13 + 10 files changed, 277 insertions(+), 344 deletions(-) diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 0d9511ac..79bc1577 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -61,14 +62,18 @@ func (v *StateResolution) LoadStateAtSnapshot( if err != nil { return nil, err } + logrus.Warn("LISTS: ", stateEntryLists) stateEntriesMap := stateEntryListMap(stateEntryLists) + logrus.Warn("Map: ", stateEntriesMap) // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { + logrus.Warn("Looking up ", stateBlockNID) entries, ok := stateEntriesMap.lookup(stateBlockNID) if !ok { + logrus.Warnf("State block NID %d: %+v", stateBlockNID, entries) // This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 0cf0bd22..02c3446e 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -88,6 +88,11 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid = ANY($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" @@ -127,6 +132,7 @@ type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt updateEventStateStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt @@ -151,6 +157,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -238,6 +245,48 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, nil } +// bulkSelectStateEventByNID lookups a list of state events by event NID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if err = rows.Err(); err != nil { + return nil, err + } + if i != len(eventNIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the internal case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)), + ) + } + return results, nil +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go index d618686f..304241c9 100644 --- a/roomserver/storage/postgres/state_block_table.go +++ b/roomserver/storage/postgres/state_block_table.go @@ -39,53 +39,27 @@ const stateDataSchema = ` -- lookup a specific (type, state_key) pair for an event. It also makes it easy -- to read the state for a given state_block_nid ordered by (type, state_key) -- which in turn makes it easier to merge state data blocks. -CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq; CREATE TABLE IF NOT EXISTS roomserver_state_block ( -- Local numeric ID for this state data. - state_block_nid bigint NOT NULL, - event_type_nid bigint NOT NULL, - event_state_key_nid bigint NOT NULL, - event_nid bigint NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + state_block_nid bigserial PRIMARY KEY, + event_nids bigint[] NOT NULL, + UNIQUE (event_nids) ); ` const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" + "INSERT INTO roomserver_state_block (event_nids)" + + " VALUES ($1)" + + " ON CONFLICT (event_nids) DO UPDATE SET event_nids=$1" + + " RETURNING state_block_nid" -const selectNextStateBlockNIDSQL = "" + - "SELECT nextval('roomserver_state_block_nid_seq')" - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" type stateBlockStatements struct { - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt } func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { @@ -97,85 +71,48 @@ func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - stateBlockNID, err := s.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids []int64 + for _, e := range entries { + nids = append(nids, int64(e.EventNID)) } - for _, entry := range entries { - _, err := s.insertStateDataStmt.ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err + err = s.insertStateDataStmt.QueryRowContext( + ctx, pq.Int64Array(nids), + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, stateBlockNIDsAsArray(stateBlockNIDs)) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err = rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result pq.Int64Array + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + for _, e := range result { + r = append(r, types.EventNID(e)) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r } if err = rows.Err(); err != nil { return nil, err @@ -186,71 +123,6 @@ func (s *stateBlockStatements) BulkSelectStateBlockEntries( return results, err } -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( - ctx, - stateBlockNIDsAsArray(stateBlockNIDs), - eventTypeNIDArray, - eventStateKeyNIDArray, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, rows.Err() -} - func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { nids := make([]int64, len(stateBlockNIDs)) for i := range stateBlockNIDs { diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go index 2a8528d5..2ec5f17d 100644 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ b/roomserver/storage/postgres/state_snapshot_table.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` @@ -86,14 +87,16 @@ func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { - nids := make([]int64, len(stateBlockNIDs)) + stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] + nids := make(int64Sorter, len(stateBlockNIDs)) for i := range stateBlockNIDs { nids[i] = int64(stateBlockNIDs[i]) } + nids = nids[:util.SortAndUnique(nids)] var id int64 - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids), int64(roomNID)).Scan(&id) + err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&id) if err != nil { return 0, err } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b48772..14c3d1a4 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -118,9 +118,45 @@ func (d *Database) StateEntriesForTuples( stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + if len(stateKeyTuples) == 0 { + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: entries, + }) + } else { + eventTypes := map[types.EventTypeNID]struct{}{} + stateKeys := map[types.EventStateKeyNID]struct{}{} + for _, t := range stateKeyTuples { + eventTypes[t.EventTypeNID] = struct{}{} + stateKeys[t.EventStateKeyNID] = struct{}{} + } + filteredEntries := []types.StateEntry{} + for _, entry := range entries { + _, tok := eventTypes[entry.EventTypeNID] + _, sok := stateKeys[entry.EventStateKeyNID] + if tok && sok { + filteredEntries = append(filteredEntries, entry) + } + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: filteredEntries, + }) + } + } + return lists, nil } func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { @@ -237,7 +273,24 @@ func (d *Database) StateBlockNIDs( func (d *Database) StateEntries( ctx context.Context, stateBlockNIDs []types.StateBlockNID, ) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + entries, err := d.StateBlockTable.BulkSelectStateBlockEntries( + ctx, stateBlockNIDs, + ) + if err != nil { + return nil, fmt.Errorf("d.StateBlockTable.BulkSelectStateBlockEntries: %w", err) + } + lists := []types.StateEntryList{} + for i, entry := range entries { + eventNIDs, err := d.EventsTable.BulkSelectStateEventByNID(ctx, entry) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + lists = append(lists, types.StateEntryList{ + StateBlockNID: stateBlockNIDs[i], + StateEntries: eventNIDs, + }) + } + return lists, nil } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 53269657..52dbe600 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -63,6 +63,11 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id IN ($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid IN ($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" @@ -232,6 +237,57 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, err } +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.StateEntry, error) { + /////////////// + iEventIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := selectStmt.QueryContext(ctx, iEventIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the internal case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)), + ) + } + return results, err +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go index 2c544f2b..762837d5 100644 --- a/roomserver/storage/sqlite3/state_block_table.go +++ b/roomserver/storage/sqlite3/state_block_table.go @@ -18,6 +18,7 @@ package sqlite3 import ( "context" "database/sql" + "encoding/json" "fmt" "sort" "strings" @@ -28,56 +29,32 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/util" + "github.com/sirupsen/logrus" ) const stateDataSchema = ` CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) + state_block_nid INTEGER PRIMARY KEY AUTOINCREMENT, + event_nids TEXT NOT NULL, + UNIQUE (event_nids) ); ` -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block +const insertStateDataSQL = ` + INSERT INTO roomserver_state_block (event_nids) + VALUES ($1) + ON CONFLICT (event_nids) DO UPDATE SET event_nids=$1 + RETURNING state_block_nid ` -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" + "SELECT state_block_nid, event_nids" + + " FROM roomserver_state_block WHERE state_block_nid IN ($1)" type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt + db *sql.DB + insertStateDataStmt *sql.Stmt + bulkSelectStateBlockEntriesStmt *sql.Stmt } func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { @@ -91,169 +68,71 @@ func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { return s, shared.StatementList{ {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, }.Prepare(db) } func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil + ctx context.Context, + txn *sql.Tx, + entries types.StateEntries, +) (id types.StateBlockNID, err error) { + entries = entries[:util.SortAndUnique(entries)] + var nids []int64 + for _, e := range entries { + nids = append(nids, int64(e.EventNID)) } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) + js, err := json.Marshal(nids) if err != nil { - return 0, err + return 0, fmt.Errorf("json.Marshal: %w", err) } - for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, err + err = s.insertStateDataStmt.QueryRowContext( + ctx, js, + ).Scan(&id) + return } func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v + ctx context.Context, stateBlockNIDs types.StateBlockNIDs, +) ([][]types.EventNID, error) { + intfs := make([]interface{}, len(stateBlockNIDs)) + for i := range stateBlockNIDs { + intfs[i] = int64(stateBlockNIDs[i]) } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) + selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(intfs)), 1) + logrus.Warnf("Query: %s", selectOrig) + logrus.Warnf("Values: %+v", intfs) selectStmt, err := s.db.Prepare(selectOrig) if err != nil { return nil, err } - rows, err := selectStmt.QueryContext(ctx, nids...) + rows, err := selectStmt.QueryContext(ctx, intfs...) if err != nil { return nil, err } defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList + results := make([][]types.EventNID, len(stateBlockNIDs)) i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { + for ; rows.Next(); i++ { + var stateBlockNID types.StateBlockNID + var result json.RawMessage + if err = rows.Scan(&stateBlockNID, &result); err != nil { return nil, err } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ + r := []types.EventNID{} + if err = json.Unmarshal(result, &r); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) } - current.StateEntries = append(current.StateEntries, entry) + results[i] = r } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) - } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) - } - - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) - if err != nil { + if err = rows.Err(); err != nil { return nil, err } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) + if i != len(stateBlockNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateBlockNIDs)) } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, nil + return results, err } type stateKeyTupleSorter []types.StateKeyTuple diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go index 99e69113..041d2504 100644 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ b/roomserver/storage/sqlite3/state_snapshot_table.go @@ -27,6 +27,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/shared" "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" + "github.com/matrix-org/util" ) const stateSnapshotSchema = ` @@ -74,8 +75,9 @@ func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { } func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, + ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs, ) (stateNID types.StateSnapshotNID, err error) { + stateBlockNIDs = stateBlockNIDs[:util.SortAndUnique(stateBlockNIDs)] stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) if err != nil { return diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 26bf5cf0..62d481e1 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -43,6 +43,7 @@ type Events interface { // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. @@ -81,14 +82,14 @@ type Transactions interface { } type StateSnapshot interface { - InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs types.StateBlockNIDs) (stateNID types.StateSnapshotNID, err error) BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) } type StateBlock interface { - BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries []types.StateEntry) (types.StateBlockNID, error) - BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) + BulkInsertStateData(ctx context.Context, txn *sql.Tx, entries types.StateEntries) (types.StateBlockNID, error) + BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs types.StateBlockNIDs) ([][]types.EventNID, error) + //BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } type RoomAliases interface { diff --git a/roomserver/types/types.go b/roomserver/types/types.go index e866f6cb..7290935f 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -40,6 +40,13 @@ type StateSnapshotNID int64 // These blocks of state data are combined to form the actual state. type StateBlockNID int64 +// StateBlockNIDs is used to sort and dedupe state block NIDs. +type StateBlockNIDs []StateBlockNID + +func (a StateBlockNIDs) Len() int { return len(a) } +func (a StateBlockNIDs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateBlockNIDs) Less(i, j int) bool { return a[i] < a[j] } + // A StateKeyTuple is a pair of a numeric event type and a numeric state key. // It is used to lookup state entries. type StateKeyTuple struct { @@ -65,6 +72,12 @@ type StateEntry struct { EventNID EventNID } +type StateEntries []StateEntry + +func (a StateEntries) Len() int { return len(a) } +func (a StateEntries) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a StateEntries) Less(i, j int) bool { return a[i].EventNID < a[j].EventNID } + // LessThan returns true if this state entry is less than the other state entry. // The ordering is arbitrary and is used to implement binary search and to efficiently deduplicate entries. func (a StateEntry) LessThan(b StateEntry) bool {