diff --git a/go.mod b/go.mod index a3d80f1b..16c25ead 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 - github.com/mattn/go-sqlite3 v1.14.6 + github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/opentracing/opentracing-go v1.2.0 diff --git a/go.sum b/go.sum index 90b5527c..f245a270 100644 --- a/go.sum +++ b/go.sum @@ -684,8 +684,8 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= -github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= -github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb h1:ax2vG2unlxsjwS7PMRo4FECIfAdQLowd6ejWYwPQhBo= +github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= diff --git a/roomserver/internal/input/input_events.go b/roomserver/internal/input/input_events.go index 2a558c48..29c76e3f 100644 --- a/roomserver/internal/input/input_events.go +++ b/roomserver/internal/input/input_events.go @@ -271,7 +271,7 @@ func (r *Inputer) calculateAndSetState( } entries = types.DeduplicateStateEntries(entries) - if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { + if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID /*nil,*/, entries); err != nil { return fmt.Errorf("r.DB.AddState: %w", err) } } else { diff --git a/roomserver/internal/perform/perform_backfill.go b/roomserver/internal/perform/perform_backfill.go index d9d720f2..cc432cba 100644 --- a/roomserver/internal/perform/perform_backfill.go +++ b/roomserver/internal/perform/perform_backfill.go @@ -146,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform } var beforeStateSnapshotNID types.StateSnapshotNID - if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { + if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID /*nil,*/, entries); err != nil { logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") return err } diff --git a/roomserver/state/state.go b/roomserver/state/state.go index 0d9511ac..c1b21c48 100644 --- a/roomserver/state/state.go +++ b/roomserver/state/state.go @@ -25,6 +25,7 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/util" "github.com/prometheus/client_golang/prometheus" + "github.com/sirupsen/logrus" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" @@ -50,32 +51,10 @@ func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResol func (v *StateResolution) LoadStateAtSnapshot( ctx context.Context, stateNID types.StateSnapshotNID, ) ([]types.StateEntry, error) { - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + fullState, err := v.db.StateEntries(ctx, stateNID) if err != nil { - return nil, err + return nil, fmt.Errorf("v.db.StateEntries: %w", err) } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - // Stable sort so that the most recent entry for each state key stays // remains later in the list than the older entries for the same state key. sort.Stable(stateEntryByStateKeySorter(fullState)) @@ -95,12 +74,10 @@ func (v *StateResolution) LoadStateAtEvent( if snapshotNID == 0 { return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) } - stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) if err != nil { - return nil, err + return nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err) } - return stateEntries, nil } @@ -114,53 +91,33 @@ func (v *StateResolution) LoadCombinedStateAfterEvents( for i, state := range prevStates { stateNIDs[i] = state.BeforeStateSnapshotNID } - // Fetch the state snapshots for the state before the each prev event from the database. - // Deduplicate the IDs before passing them to the database. - // There could be duplicates because the events could be state events where - // the snapshot of the room state before them was the same. - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, UniqueStateSnapshotNIDs(stateNIDs)) - if err != nil { - return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err) - } - var stateBlockNIDs []types.StateBlockNID - for _, list := range stateBlockNIDLists { - stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...) - } // Fetch the state entries that will be combined to create the snapshots. // Deduplicate the IDs before passing them to the database. // There could be duplicates because a block of state entries could be reused by // multiple snapshots. - stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) - if err != nil { - return nil, fmt.Errorf("v.db.StateEntries: %w", err) + stateEntriesMap := map[types.StateSnapshotNID][]types.StateEntry{} + for _, nid := range stateNIDs { + entries, err := v.db.StateEntries(ctx, nid) + if err != nil { + return nil, fmt.Errorf("v.db.StateEntries: %w", err) + } + stateEntriesMap[nid] = entries } - stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists) - stateEntriesMap := stateEntryListMap(stateEntryLists) // Combine the entries from all the snapshots of state after each prev event into a single list. var combined []types.StateEntry for _, prevState := range prevStates { - // Grab the list of state data NIDs for this snapshot. - stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID)) - } - // Combine all the state entries for this snapshot. // The order of state block NIDs in the list tells us the order to combine them in. var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) + entries, ok := stateEntriesMap[prevState.BeforeStateSnapshotNID] + if !ok { + // This should only get hit if the database is corrupt. + // It should be impossible for an event to reference a NID that doesn't exist + panic(fmt.Errorf("Corrupt DB: Missing state snapshot %d", prevState.BeforeStateSnapshotNID)) } + fullState = append(fullState, entries...) if prevState.IsStateEvent() && !prevState.IsRejected { // If the prev event was a state event then add an entry for the event itself // so that we get the state after the event rather than the state before. @@ -192,13 +149,13 @@ func (v *StateResolution) DifferenceBetweeenStateSnapshots( if oldStateNID != 0 { oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err) } } if newStateNID != 0 { newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) if err != nil { - return nil, nil, err + return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err) } } @@ -299,33 +256,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples( stateNID types.StateSnapshotNID, stateKeyTuples []types.StateKeyTuple, ) ([]types.StateEntry, error) { - stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) + fullState, err := v.db.StateEntries(ctx, stateNID) if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := v.db.StateEntriesForTuples( - ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples, - ) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // If the block is missing from the map it means that none of its entries matched a requested tuple. - // This can happen if the block doesn't contain an update for one of the requested tuples. - // If none of the requested tuples are in the block then it can be safely skipped. - continue - } - fullState = append(fullState, entries...) + return nil, fmt.Errorf("v.db.StateEntries: %w", err) } // Stable sort so that the most recent entry for each state key stays @@ -549,7 +482,8 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( // 2) There weren't any prev_events for this event so the state is // empty. metrics.algorithm = "empty_state" - stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) + stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID /*nil,*/, nil) + logrus.Warnf("Empty prev state added state snapshot %d", stateNID) if err != nil { err = fmt.Errorf("v.db.AddState: %w", err) } @@ -566,28 +500,56 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( metrics.algorithm = "no_change" return metrics.stop(prevState.BeforeStateSnapshotNID, nil) } - // The previous event was a state event so we need to store a copy - // of the previous state updated with that event. - stateBlockNIDLists, err := v.db.StateBlockNIDs( - ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, + + oldState, err := v.db.StateEntries( + ctx, prevState.BeforeStateSnapshotNID, ) if err != nil { - metrics.algorithm = "_load_state_blocks" - return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) + return 0, fmt.Errorf("v.db.StateEntries: %w", err) } - stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs - if len(stateBlockNIDs) < maxStateBlockNIDs { - // 4) The number of state data blocks is small enough that we can just - // add the state event as a block of size one to the end of the blocks. - metrics.algorithm = "single_delta" - stateNID, err := v.db.AddState( - ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + found := false + for _, s := range oldState { + if s.EventNID == prevState.StateEntry.EventNID { + found = true + break + } + } + if !found { + oldState = append(oldState, prevState.StateEntry) + } + + stateNID, err := v.db.AddState( + ctx, v.roomInfo.RoomNID, oldState, + ) + logrus.Warnf("Single prev state added state snapshot %d", stateNID) + if err != nil { + return 0, fmt.Errorf("v.db.AddState: %w", err) + } + return stateNID, nil + /* + // The previous event was a state event so we need to store a copy + // of the previous state updated with that event. + stateBlockNIDLists, err := v.db.StateBlockNIDs( + ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID}, ) if err != nil { - err = fmt.Errorf("v.db.AddState: %w", err) + metrics.algorithm = "_load_state_blocks" + return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err)) } - return metrics.stop(stateNID, err) - } + stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs + if len(stateBlockNIDs) < maxStateBlockNIDs { + // 4) The number of state data blocks is small enough that we can just + // add the state event as a block of size one to the end of the blocks. + metrics.algorithm = "single_delta" + stateNID, err := v.db.AddState( + ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, + ) + if err != nil { + err = fmt.Errorf("v.db.AddState: %w", err) + } + return metrics.stop(stateNID, err) + } + */ // If there are too many deltas then we need to calculate the full state // So fall through to calculateAndStoreStateAfterManyEvents } @@ -596,6 +558,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents( if err != nil { return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) } + logrus.Warnf("Multiple prev states added state snapshot %d", stateNID) return stateNID, nil } @@ -626,7 +589,7 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents( // previous state. metrics.conflictLength = conflictLength metrics.fullStateLength = len(state) - return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) + return metrics.stop(v.db.AddState(ctx, roomNID /*nil,*/, state)) } func (v *StateResolution) calculateStateAfterManyEvents( @@ -996,34 +959,6 @@ func (s stateEntrySorter) Len() int { return len(s) } func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -type stateBlockNIDListMap []types.StateBlockNIDList - -func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) { - list := []types.StateBlockNIDList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateSnapshotNID >= stateNID - }) - if i < len(list) && list[i].StateSnapshotNID == stateNID { - ok = true - stateBlockNIDs = list[i].StateBlockNIDs - } - return -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - type stateEntryByStateKeySorter []types.StateEntry func (s stateEntryByStateKeySorter) Len() int { return len(s) } diff --git a/roomserver/storage/interface.go b/roomserver/storage/interface.go index d2b0e75c..fd1bbb4f 100644 --- a/roomserver/storage/interface.go +++ b/roomserver/storage/interface.go @@ -33,7 +33,6 @@ type Database interface { AddState( ctx context.Context, roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (types.StateSnapshotNID, error) // Look up the state of a room at each event for a list of string event IDs. @@ -47,21 +46,9 @@ type Database interface { // Look up the numeric IDs for a list of string event state keys. // Returns a map from string state key to numeric ID for the state key. EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) - // Look up the numeric state data IDs for each numeric state snapshot ID - // The returned slice is sorted by numeric state snapshot ID. - StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) // Look up the state data for each numeric state data ID // The returned slice is sorted by numeric state data ID. - StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) - // Look up the state data for the state key tuples for each numeric state block ID - // This is used to fetch a subset of the room state at a snapshot. - // If a block doesn't contain any of the requested tuples then it can be discarded from the result. - // The returned slice is sorted by numeric state block ID. - StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, - ) ([]types.StateEntryList, error) + StateEntries(ctx context.Context, stateSnapshotNID types.StateSnapshotNID) ([]types.StateEntry, error) // Look up the Events for a list of numeric event IDs. // Returns a sorted list of events. Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) diff --git a/roomserver/storage/postgres/events_table.go b/roomserver/storage/postgres/events_table.go index 0cf0bd22..c8bd0813 100644 --- a/roomserver/storage/postgres/events_table.go +++ b/roomserver/storage/postgres/events_table.go @@ -88,6 +88,12 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id = ANY($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +// Bulk lookup of events by numeric ID. +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid = ANY($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id = ANY($1)" @@ -127,6 +133,7 @@ type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt updateEventStateStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt @@ -151,6 +158,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -238,6 +246,48 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, nil } +// bulkSelectStateEventByNID lookups a list of state events by internal numeric ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.StateEntry, error) { + rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event NIDs + // because of the unique constraint on event NIDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than NIDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if err = rows.Err(); err != nil { + return nil, err + } + if i != len(eventNIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the internal case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)), + ) + } + return results, nil +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/postgres/state_block_table.go b/roomserver/storage/postgres/state_block_table.go deleted file mode 100644 index d618686f..00000000 --- a/roomserver/storage/postgres/state_block_table.go +++ /dev/null @@ -1,292 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - "sort" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/util" -) - -const stateDataSchema = ` --- The state data map. --- Designed to give enough information to run the state resolution algorithm --- without hitting the database in the internal case. --- TODO: Is it worth replacing the unique btree index with a covering index so --- that postgres could lookup the state using an index-only scan? --- The type and state_key are included in the index to make it easier to --- lookup a specific (type, state_key) pair for an event. It also makes it easy --- to read the state for a given state_block_nid ordered by (type, state_key) --- which in turn makes it easier to merge state data blocks. -CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq; -CREATE TABLE IF NOT EXISTS roomserver_state_block ( - -- Local numeric ID for this state data. - state_block_nid bigint NOT NULL, - event_type_nid bigint NOT NULL, - event_state_key_nid bigint NOT NULL, - event_nid bigint NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) -); -` - -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = "" + - "SELECT nextval('roomserver_state_block_nid_seq')" - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. -const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid = ANY($1)" + - " AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -type stateBlockStatements struct { - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt -} - -func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{} - _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } - - return s, shared.StatementList{ - {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, - {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.Prepare(db) -} - -func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, - txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - stateBlockNID, err := s.selectNextStateBlockNID(ctx) - if err != nil { - return 0, err - } - for _, entry := range entries { - _, err := s.insertStateDataStmt.ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, nil -} - -func (s *stateBlockStatements) selectNextStateBlockNID( - ctx context.Context, -) (types.StateBlockNID, error) { - var stateBlockNID int64 - err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID) - return types.StateBlockNID(stateBlockNID), err -} - -func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids)) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList - i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err = rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ - } - current.StateEntries = append(current.StateEntries, entry) - } - if err = rows.Err(); err != nil { - return nil, err - } - if i != len(stateBlockNIDs) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs)) - } - return results, err -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext( - ctx, - stateBlockNIDsAsArray(stateBlockNIDs), - eventTypeNIDArray, - eventStateKeyNIDArray, - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, rows.Err() -} - -func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - return pq.Int64Array(nids) -} - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) { - eventTypeNIDs = make(pq.Int64Array, len(s)) - eventStateKeyNIDs = make(pq.Int64Array, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/postgres/state_block_table_test.go b/roomserver/storage/postgres/state_block_table_test.go deleted file mode 100644 index a0e2ec95..00000000 --- a/roomserver/storage/postgres/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/storage/postgres/state_snapshot_table.go b/roomserver/storage/postgres/state_snapshot_table.go deleted file mode 100644 index 63175955..00000000 --- a/roomserver/storage/postgres/state_snapshot_table.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package postgres - -import ( - "context" - "database/sql" - "fmt" - - "github.com/lib/pq" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" - "github.com/matrix-org/dendrite/roomserver/types" -) - -const stateSnapshotSchema = ` --- The state of a room before an event. --- Stored as a list of state_block entries stored in a separate table. --- The actual state is constructed by combining all the state_block entries --- referenced by state_block_nids together. If the same state key tuple appears --- multiple times then the entry from the later state_block clobbers the earlier --- entries. --- This encoding format allows us to implement a delta encoding which is useful --- because room state tends to accumulate small changes over time. Although if --- the list of deltas becomes too long it becomes more efficient to encode --- the full state under single state_block_nid. -CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq; -CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( - -- Local numeric ID for the state. - state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'), - -- Local numeric ID of the room this state is for. - -- Unused in normal operation, but useful for background work or ad-hoc debugging. - room_nid bigint NOT NULL, - -- List of state_block_nids, stored sorted by state_block_nid. - state_block_nids bigint[] NOT NULL -); -` - -const insertStateSQL = "" + - "INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" + - " VALUES ($1, $2)" + - " RETURNING state_snapshot_nid" - -// Bulk state data NID lookup. -// Sorting by state_snapshot_nid means we can use binary search over the result -// to lookup the state data NIDs for a state snapshot NID. -const bulkSelectStateBlockNIDsSQL = "" + - "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + - " WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC" - -type stateSnapshotStatements struct { - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt -} - -func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { - s := &stateSnapshotStatements{} - _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } - - return s, shared.StatementList{ - {&s.insertStateStmt, insertStateSQL}, - {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.Prepare(db) -} - -func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, -) (stateNID types.StateSnapshotNID, err error) { - nids := make([]int64, len(stateBlockNIDs)) - for i := range stateBlockNIDs { - nids[i] = int64(stateBlockNIDs[i]) - } - err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID) - return -} - -func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - nids := make([]int64, len(stateNIDs)) - for i := range stateNIDs { - nids[i] = int64(stateNIDs[i]) - } - rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids)) - if err != nil { - return nil, err - } - defer rows.Close() // nolint: errcheck - results := make([]types.StateBlockNIDList, len(stateNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - var stateBlockNIDs pq.Int64Array - if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil { - return nil, err - } - result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs)) - for k := range stateBlockNIDs { - result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k]) - } - } - if err = rows.Err(); err != nil { - return nil, err - } - if i != len(stateNIDs) { - return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) - } - return results, nil -} diff --git a/roomserver/storage/postgres/state_table.go b/roomserver/storage/postgres/state_table.go new file mode 100644 index 00000000..b1573f63 --- /dev/null +++ b/roomserver/storage/postgres/state_table.go @@ -0,0 +1,124 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/lib/pq" + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_state ( + state_nid BIGSERIAL PRIMARY KEY, + room_nid bigint NOT NULL, + event_nids bigint[] NOT NULL, + UNIQUE (room_nid, event_nids), + CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid) +); +` + +const insertNewStateSnapshotSQL = "" + + "INSERT INTO roomserver_state (room_nid, event_nids)" + + " VALUES ($1, $2)" + + " ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" + + " RETURNING state_nid" + +const bulkSelectNewStateSnapshotSQL = "" + + "SELECT state_nid, event_nids" + + " FROM roomserver_state WHERE state_nid = ANY($1)" + + " ORDER BY state_nid" + +type stateStatements struct { + insertStateStmt *sql.Stmt + bulkSelectStateStmt *sql.Stmt +} + +func NewPostgresStateTable(db *sql.DB) (tables.State, error) { + s := &stateStatements{} + _, err := db.Exec(stateSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertStateStmt, insertNewStateSnapshotSQL}, + {&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL}, + }.Prepare(db) +} + +func (s *stateStatements) InsertState( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, +) (types.StateSnapshotNID, error) { + stmt := sqlutil.TxStmt(txn, s.insertStateStmt) + var id int64 + var err error + eventNIDs = types.DeduplicateEventNIDs(eventNIDs) + if err = stmt.QueryRowContext(ctx, int64(roomNID), eventNIDsAsArray(eventNIDs)).Scan(&id); err != nil { + return 0, fmt.Errorf("stmt.ExecContext: %w", err) + } + return types.StateSnapshotNID(id), err +} + +func (s *stateStatements) BulkSelectState( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) (map[types.StateSnapshotNID][]types.EventNID, error) { + rows, err := s.bulkSelectStateStmt.QueryContext(ctx, stateSnapshotNIDsAsArray(stateNIDs)) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") + + results := map[types.StateSnapshotNID][]types.EventNID{} + for rows.Next() { + var stateNID int64 + var eventNIDs pq.Int64Array + if err = rows.Scan(&stateNID, &eventNIDs); err != nil { + return nil, err + } + for _, id := range eventNIDs { + results[types.StateSnapshotNID(stateNID)] = append( + results[types.StateSnapshotNID(stateNID)], + types.EventNID(id), + ) + } + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("rows.Err: %w", err) + } + if len(results) != len(stateNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs)) + } + return results, err +} + +func stateSnapshotNIDsAsArray(stateSnapshotNIDs []types.StateSnapshotNID) pq.Int64Array { + nids := make([]int64, len(stateSnapshotNIDs)) + for i := range stateSnapshotNIDs { + nids[i] = int64(stateSnapshotNIDs[i]) + } + return nids +} diff --git a/roomserver/storage/postgres/storage.go b/roomserver/storage/postgres/storage.go index bb3f841d..14677c17 100644 --- a/roomserver/storage/postgres/storage.go +++ b/roomserver/storage/postgres/storage.go @@ -86,11 +86,17 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro if err != nil { return err } - stateBlock, err := NewPostgresStateBlockTable(db) - if err != nil { - return err - } - stateSnapshot, err := NewPostgresStateSnapshotTable(db) + /* + stateBlock, err := NewPostgresStateBlockTable(db) + if err != nil { + return err + } + stateSnapshot, err := NewPostgresStateSnapshotTable(db) + if err != nil { + return err + } + */ + state, err := NewPostgresStateTable(db) if err != nil { return err } @@ -128,14 +134,15 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro EventsTable: events, RoomsTable: rooms, TransactionsTable: transactions, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, - PrevEventsTable: prevEvents, - RoomAliasesTable: roomAliases, - InvitesTable: invites, - MembershipTable: membership, - PublishedTable: published, - RedactionsTable: redactions, + StateTable: state, + // StateBlockTable: stateBlock, + // StateSnapshotTable: stateSnapshot, + PrevEventsTable: prevEvents, + RoomAliasesTable: roomAliases, + InvitesTable: invites, + MembershipTable: membership, + PublishedTable: published, + RedactionsTable: redactions, } return nil } diff --git a/roomserver/storage/shared/storage.go b/roomserver/storage/shared/storage.go index 24b48772..1da63846 100644 --- a/roomserver/storage/shared/storage.go +++ b/roomserver/storage/shared/storage.go @@ -5,7 +5,6 @@ import ( "database/sql" "encoding/json" "fmt" - "sort" "github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/sqlutil" @@ -13,7 +12,6 @@ import ( "github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/gomatrixserverlib" - "github.com/matrix-org/util" "github.com/tidwall/gjson" ) @@ -36,8 +34,7 @@ type Database struct { EventStateKeysTable tables.EventStateKeys RoomsTable tables.Rooms TransactionsTable tables.Transactions - StateSnapshotTable tables.StateSnapshot - StateBlockTable tables.StateBlock + StateTable tables.State RoomAliasesTable tables.RoomAliases PrevEventsTable tables.PreviousEvents InvitesTable tables.Invites @@ -113,16 +110,6 @@ func (d *Database) StateEntriesForEventIDs( return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) } -func (d *Database) StateEntriesForTuples( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectFilteredStateBlockEntries( - ctx, stateBlockNIDs, stateKeyTuples, - ) -} - func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { return &roomInfo, nil @@ -138,21 +125,16 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo func (d *Database) AddState( ctx context.Context, roomNID types.RoomNID, - stateBlockNIDs []types.StateBlockNID, state []types.StateEntry, ) (stateNID types.StateSnapshotNID, err error) { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { - if len(state) > 0 { - var stateBlockNID types.StateBlockNID - stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) - if err != nil { - return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err) - } - stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID) + eventNIDs := make([]types.EventNID, 0, len(state)) + for _, s := range state { + eventNIDs = append(eventNIDs, s.EventNID) } - stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) + stateNID, err = d.StateTable.InsertState(ctx, txn, roomNID, eventNIDs) if err != nil { - return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err) + return fmt.Errorf("d.StateTable.InsertState: %w", err) } return nil }) @@ -228,16 +210,22 @@ func (d *Database) LatestEventIDs( return } -func (d *Database) StateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs) -} - func (d *Database) StateEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) + ctx context.Context, stateSnapshotNID types.StateSnapshotNID, +) ([]types.StateEntry, error) { + nids, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{stateSnapshotNID}) + if err != nil { + return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err) + } + state, ok := nids[stateSnapshotNID] + if !ok { + return nil, fmt.Errorf("state snapshot %d not found", stateSnapshotNID) + } + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, state) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) + } + return entries, nil } func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { @@ -817,9 +805,17 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s if err != nil { return nil, err } - entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + snapshots, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID}) if err != nil { - return nil, err + return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err) + } + nids, ok := snapshots[roomInfo.StateSnapshotNID] + if !ok { + return nil, fmt.Errorf("state snapshot %d not found", roomInfo.StateSnapshotNID) + } + entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nids) + if err != nil { + return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err) } var eventNIDs []types.EventNID for _, e := range entries { @@ -938,7 +934,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu if roomInfo == nil || roomInfo.IsStub { continue } - entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) + entries, err2 := d.StateEntries(ctx, roomInfo.StateSnapshotNID) if err2 != nil { return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2) } @@ -1039,65 +1035,3 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget) }) } - -// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops -// it should live in this package! - -func (d *Database) loadStateAtSnapshot( - ctx context.Context, stateNID types.StateSnapshotNID, -) ([]types.StateEntry, error) { - stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) - if err != nil { - return nil, err - } - // We've asked for exactly one snapshot from the db so we should have exactly one entry in the result. - stateBlockNIDList := stateBlockNIDLists[0] - - stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs) - if err != nil { - return nil, err - } - stateEntriesMap := stateEntryListMap(stateEntryLists) - - // Combine all the state entries for this snapshot. - // The order of state block NIDs in the list tells us the order to combine them in. - var fullState []types.StateEntry - for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs { - entries, ok := stateEntriesMap.lookup(stateBlockNID) - if !ok { - // This should only get hit if the database is corrupt. - // It should be impossible for an event to reference a NID that doesn't exist - panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID)) - } - fullState = append(fullState, entries...) - } - - // Stable sort so that the most recent entry for each state key stays - // remains later in the list than the older entries for the same state key. - sort.Stable(stateEntryByStateKeySorter(fullState)) - // Unique returns the last entry and hence the most recent entry for each state key. - fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))] - return fullState, nil -} - -type stateEntryListMap []types.StateEntryList - -func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) { - list := []types.StateEntryList(m) - i := sort.Search(len(list), func(i int) bool { - return list[i].StateBlockNID >= stateBlockNID - }) - if i < len(list) && list[i].StateBlockNID == stateBlockNID { - ok = true - stateEntries = list[i].StateEntries - } - return -} - -type stateEntryByStateKeySorter []types.StateEntry - -func (s stateEntryByStateKeySorter) Len() int { return len(s) } -func (s stateEntryByStateKeySorter) Less(i, j int) bool { - return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple) -} -func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/events_table.go b/roomserver/storage/sqlite3/events_table.go index 53269657..10611067 100644 --- a/roomserver/storage/sqlite3/events_table.go +++ b/roomserver/storage/sqlite3/events_table.go @@ -63,6 +63,11 @@ const bulkSelectStateEventByIDSQL = "" + " WHERE event_id IN ($1)" + " ORDER BY event_type_nid, event_state_key_nid ASC" +const bulkSelectStateEventByNIDSQL = "" + + "SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" + + " WHERE event_nid IN ($1)" + + " ORDER BY event_type_nid, event_state_key_nid ASC" + const bulkSelectStateAtEventByIDSQL = "" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + " WHERE event_id IN ($1)" @@ -103,6 +108,7 @@ type eventStatements struct { insertEventStmt *sql.Stmt selectEventStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt + bulkSelectStateEventByNIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt updateEventStateStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt @@ -128,6 +134,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) { {&s.insertEventStmt, insertEventSQL}, {&s.selectEventStmt, selectEventSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, + {&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, @@ -232,6 +239,57 @@ func (s *eventStatements) BulkSelectStateEventByID( return results, err } +// bulkSelectStateEventByID lookups a list of state events by event ID. +// If any of the requested events are missing from the database it returns a types.MissingEventError +func (s *eventStatements) BulkSelectStateEventByNID( + ctx context.Context, eventNIDs []types.EventNID, +) ([]types.StateEntry, error) { + /////////////// + iEventNIDs := make([]interface{}, len(eventNIDs)) + for k, v := range eventNIDs { + iEventNIDs[k] = v + } + selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + /////////////// + + rows, err := selectStmt.QueryContext(ctx, iEventNIDs...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed") + // We know that we will only get as many results as event IDs + // because of the unique constraint on event IDs. + // So we can allocate an array of the correct size now. + // We might get fewer results than IDs so we adjust the length of the slice before returning it. + results := make([]types.StateEntry, len(eventNIDs)) + i := 0 + for ; rows.Next(); i++ { + result := &results[i] + if err = rows.Scan( + &result.EventTypeNID, + &result.EventStateKeyNID, + &result.EventNID, + ); err != nil { + return nil, err + } + } + if i != len(eventNIDs) { + // If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have. + // We don't know which ones were missing because we don't return the string IDs in the query. + // However it should be possible debug this by replaying queries or entries from the input kafka logs. + // If this turns out to be impossible and we do need the debug information here, it would be better + // to do it as a separate query rather than slowing down/complicating the internal case. + return nil, types.MissingEventError( + fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)), + ) + } + return results, err +} + // bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. diff --git a/roomserver/storage/sqlite3/state_block_table.go b/roomserver/storage/sqlite3/state_block_table.go deleted file mode 100644 index 2c544f2b..00000000 --- a/roomserver/storage/sqlite3/state_block_table.go +++ /dev/null @@ -1,289 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "fmt" - "sort" - "strings" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" - "github.com/matrix-org/dendrite/roomserver/types" - "github.com/matrix-org/util" -) - -const stateDataSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_state_block ( - state_block_nid INTEGER NOT NULL, - event_type_nid INTEGER NOT NULL, - event_state_key_nid INTEGER NOT NULL, - event_nid INTEGER NOT NULL, - UNIQUE (state_block_nid, event_type_nid, event_state_key_nid) - ); -` - -const insertStateDataSQL = "" + - "INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" + - " VALUES ($1, $2, $3, $4)" - -const selectNextStateBlockNIDSQL = ` -SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block -` - -// Bulk state lookup by numeric state block ID. -// Sort by the state_block_nid, event_type_nid, event_state_key_nid -// This means that all the entries for a given state_block_nid will appear -// together in the list and those entries will sorted by event_type_nid -// and event_state_key_nid. This property makes it easier to merge two -// state data blocks together. -const bulkSelectStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -// Bulk state lookup by numeric state block ID. -// Filters the rows in each block to the requested types and state keys. -// We would like to restrict to particular type state key pairs but we are -// restricted by the query language to pull the cross product of a list -// of types and a list state_keys. So we have to filter the result in the -// application to restrict it to the list of event types and state keys we -// actually wanted. -const bulkSelectFilteredStateBlockEntriesSQL = "" + - "SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" + - " FROM roomserver_state_block WHERE state_block_nid IN ($1)" + - " AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" + - " ORDER BY state_block_nid, event_type_nid, event_state_key_nid" - -type stateBlockStatements struct { - db *sql.DB - insertStateDataStmt *sql.Stmt - selectNextStateBlockNIDStmt *sql.Stmt - bulkSelectStateBlockEntriesStmt *sql.Stmt - bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt -} - -func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) { - s := &stateBlockStatements{ - db: db, - } - _, err := db.Exec(stateDataSchema) - if err != nil { - return nil, err - } - - return s, shared.StatementList{ - {&s.insertStateDataStmt, insertStateDataSQL}, - {&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL}, - {&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL}, - {&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL}, - }.Prepare(db) -} - -func (s *stateBlockStatements) BulkInsertStateData( - ctx context.Context, txn *sql.Tx, - entries []types.StateEntry, -) (types.StateBlockNID, error) { - if len(entries) == 0 { - return 0, nil - } - var stateBlockNID types.StateBlockNID - err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID) - if err != nil { - return 0, err - } - for _, entry := range entries { - _, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext( - ctx, - int64(stateBlockNID), - int64(entry.EventTypeNID), - int64(entry.EventStateKeyNID), - int64(entry.EventNID), - ) - if err != nil { - return 0, err - } - } - return stateBlockNID, err -} - -func (s *stateBlockStatements) BulkSelectStateBlockEntries( - ctx context.Context, stateBlockNIDs []types.StateBlockNID, -) ([]types.StateEntryList, error) { - nids := make([]interface{}, len(stateBlockNIDs)) - for k, v := range stateBlockNIDs { - nids[k] = v - } - selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - selectStmt, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - rows, err := selectStmt.QueryContext(ctx, nids...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") - - results := make([]types.StateEntryList, len(stateBlockNIDs)) - // current is a pointer to the StateEntryList to append the state entries to. - var current *types.StateEntryList - i := 0 - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we start appending to the next entry in the list. - current = &results[i] - current.StateBlockNID = types.StateBlockNID(stateBlockNID) - i++ - } - current.StateEntries = append(current.StateEntries, entry) - } - if i != len(nids) { - return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids)) - } - return results, nil -} - -func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries( - ctx context.Context, - stateBlockNIDs []types.StateBlockNID, - stateKeyTuples []types.StateKeyTuple, -) ([]types.StateEntryList, error) { - tuples := stateKeyTupleSorter(stateKeyTuples) - // Sort the tuples so that we can run binary search against them as we filter the rows returned by the db. - sort.Sort(tuples) - - eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays() - sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1) - sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1) - - var params []interface{} - for _, val := range stateBlockNIDs { - params = append(params, int64(val)) - } - for _, val := range eventTypeNIDArray { - params = append(params, val) - } - for _, val := range eventStateKeyNIDArray { - params = append(params, val) - } - - rows, err := s.db.QueryContext( - ctx, - sqlStatement, - params..., - ) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed") - - var results []types.StateEntryList - var current types.StateEntryList - for rows.Next() { - var ( - stateBlockNID int64 - eventTypeNID int64 - eventStateKeyNID int64 - eventNID int64 - entry types.StateEntry - ) - if err := rows.Scan( - &stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID, - ); err != nil { - return nil, err - } - entry.EventTypeNID = types.EventTypeNID(eventTypeNID) - entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID) - entry.EventNID = types.EventNID(eventNID) - - // We can use binary search here because we sorted the tuples earlier - if !tuples.contains(entry.StateKeyTuple) { - // The select will return the cross product of types and state keys. - // So we need to check if type of the entry is in the list. - continue - } - - if types.StateBlockNID(stateBlockNID) != current.StateBlockNID { - // The state entry row is for a different state data block to the current one. - // So we append the current entry to the results and start adding to a new one. - // The first time through the loop current will be empty. - if current.StateEntries != nil { - results = append(results, current) - } - current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)} - } - current.StateEntries = append(current.StateEntries, entry) - } - // Add the last entry to the list if it is not empty. - if current.StateEntries != nil { - results = append(results, current) - } - return results, nil -} - -type stateKeyTupleSorter []types.StateKeyTuple - -func (s stateKeyTupleSorter) Len() int { return len(s) } -func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } -func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - -// Check whether a tuple is in the list. Assumes that the list is sorted. -func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool { - i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) }) - return i < len(s) && s[i] == value -} - -// List the unique eventTypeNIDs and eventStateKeyNIDs. -// Assumes that the list is sorted. -func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) { - eventTypeNIDs = make([]int64, len(s)) - eventStateKeyNIDs = make([]int64, len(s)) - for i := range s { - eventTypeNIDs[i] = int64(s[i].EventTypeNID) - eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID) - } - eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))] - eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))] - return -} - -type int64Sorter []int64 - -func (s int64Sorter) Len() int { return len(s) } -func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] } -func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/roomserver/storage/sqlite3/state_block_table_test.go b/roomserver/storage/sqlite3/state_block_table_test.go deleted file mode 100644 index 98439f5c..00000000 --- a/roomserver/storage/sqlite3/state_block_table_test.go +++ /dev/null @@ -1,86 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "sort" - "testing" - - "github.com/matrix-org/dendrite/roomserver/types" -) - -func TestStateKeyTupleSorter(t *testing.T) { - input := stateKeyTupleSorter{ - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 1}, - } - want := []types.StateKeyTuple{ - {EventTypeNID: 1, EventStateKeyNID: 1}, - {EventTypeNID: 1, EventStateKeyNID: 2}, - {EventTypeNID: 1, EventStateKeyNID: 4}, - {EventTypeNID: 2, EventStateKeyNID: 2}, - } - doNotWant := []types.StateKeyTuple{ - {EventTypeNID: 0, EventStateKeyNID: 0}, - {EventTypeNID: 1, EventStateKeyNID: 3}, - {EventTypeNID: 2, EventStateKeyNID: 1}, - {EventTypeNID: 3, EventStateKeyNID: 1}, - } - wantTypeNIDs := []int64{1, 2} - wantStateKeyNIDs := []int64{1, 2, 4} - - // Sort the input and check it's in the right order. - sort.Sort(input) - gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays() - - for i := range want { - if input[i] != want[i] { - t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i]) - } - - if !input.contains(want[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i]) - } - } - - for i := range doNotWant { - if input.contains(doNotWant[i]) { - t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i]) - } - } - - if len(wantTypeNIDs) != len(gotTypeNIDs) { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - - for i := range wantTypeNIDs { - if wantTypeNIDs[i] != gotTypeNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } - - if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) { - t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs) - } - - for i := range wantStateKeyNIDs { - if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] { - t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs) - } - } -} diff --git a/roomserver/storage/sqlite3/state_snapshot_table.go b/roomserver/storage/sqlite3/state_snapshot_table.go deleted file mode 100644 index bf49f62c..00000000 --- a/roomserver/storage/sqlite3/state_snapshot_table.go +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2017-2018 New Vector Ltd -// Copyright 2019-2020 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package sqlite3 - -import ( - "context" - "database/sql" - "encoding/json" - "fmt" - "strings" - - "github.com/matrix-org/dendrite/internal" - "github.com/matrix-org/dendrite/internal/sqlutil" - "github.com/matrix-org/dendrite/roomserver/storage/shared" - "github.com/matrix-org/dendrite/roomserver/storage/tables" - "github.com/matrix-org/dendrite/roomserver/types" -) - -const stateSnapshotSchema = ` - CREATE TABLE IF NOT EXISTS roomserver_state_snapshots ( - state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT, - room_nid INTEGER NOT NULL, - state_block_nids TEXT NOT NULL DEFAULT '[]' - ); -` - -const insertStateSQL = ` - INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids) - VALUES ($1, $2);` - -// Bulk state data NID lookup. -// Sorting by state_snapshot_nid means we can use binary search over the result -// to lookup the state data NIDs for a state snapshot NID. -const bulkSelectStateBlockNIDsSQL = "" + - "SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" + - " WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC" - -type stateSnapshotStatements struct { - db *sql.DB - insertStateStmt *sql.Stmt - bulkSelectStateBlockNIDsStmt *sql.Stmt -} - -func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) { - s := &stateSnapshotStatements{ - db: db, - } - _, err := db.Exec(stateSnapshotSchema) - if err != nil { - return nil, err - } - - return s, shared.StatementList{ - {&s.insertStateStmt, insertStateSQL}, - {&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL}, - }.Prepare(db) -} - -func (s *stateSnapshotStatements) InsertState( - ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, -) (stateNID types.StateSnapshotNID, err error) { - stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs) - if err != nil { - return - } - insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt) - res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON)) - if err != nil { - return 0, err - } - lastRowID, err := res.LastInsertId() - if err != nil { - return 0, err - } - stateNID = types.StateSnapshotNID(lastRowID) - return -} - -func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs( - ctx context.Context, stateNIDs []types.StateSnapshotNID, -) ([]types.StateBlockNIDList, error) { - nids := make([]interface{}, len(stateNIDs)) - for k, v := range stateNIDs { - nids[k] = v - } - selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1) - selectStmt, err := s.db.Prepare(selectOrig) - if err != nil { - return nil, err - } - - rows, err := selectStmt.QueryContext(ctx, nids...) - if err != nil { - return nil, err - } - defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed") - results := make([]types.StateBlockNIDList, len(stateNIDs)) - i := 0 - for ; rows.Next(); i++ { - result := &results[i] - var stateBlockNIDsJSON string - if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil { - return nil, err - } - if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil { - return nil, err - } - } - if i != len(stateNIDs) { - return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs)) - } - return results, nil -} diff --git a/roomserver/storage/sqlite3/state_table.go b/roomserver/storage/sqlite3/state_table.go new file mode 100644 index 00000000..a7c33205 --- /dev/null +++ b/roomserver/storage/sqlite3/state_table.go @@ -0,0 +1,132 @@ +// Copyright 2017-2018 New Vector Ltd +// Copyright 2019-2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + + "github.com/matrix-org/dendrite/internal" + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/roomserver/storage/shared" + "github.com/matrix-org/dendrite/roomserver/storage/tables" + "github.com/matrix-org/dendrite/roomserver/types" +) + +const stateSchema = ` +CREATE TABLE IF NOT EXISTS roomserver_state ( + state_nid INTEGER PRIMARY KEY AUTOINCREMENT, + room_nid INTEGER NOT NULL, + event_nids TEXT NOT NULL, + UNIQUE (room_nid, event_nids), + CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid) +); +` + +const insertNewStateSnapshotSQL = "" + + "INSERT INTO roomserver_state (room_nid, event_nids)" + + " VALUES ($1, $2)" + + " ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" + + " RETURNING state_nid" + +const bulkSelectNewStateSnapshotSQL = "" + + "SELECT state_nid, event_nids" + + " FROM roomserver_state WHERE state_nid IN ($1)" + + " ORDER BY state_nid" + +type stateStatements struct { + db *sql.DB + insertStateStmt *sql.Stmt + bulkSelectStateStmt *sql.Stmt +} + +func NewPostgresStateTable(db *sql.DB) (tables.State, error) { + s := &stateStatements{db: db} + _, err := db.Exec(stateSchema) + if err != nil { + return nil, err + } + + return s, shared.StatementList{ + {&s.insertStateStmt, insertNewStateSnapshotSQL}, + {&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL}, + }.Prepare(db) +} + +func (s *stateStatements) InsertState( + ctx context.Context, + txn *sql.Tx, + roomNID types.RoomNID, + eventNIDs []types.EventNID, +) (types.StateSnapshotNID, error) { + stmt := sqlutil.TxStmt(txn, s.insertStateStmt) + value, err := json.Marshal(types.DeduplicateEventNIDs(eventNIDs)) + if err != nil { + return 0, fmt.Errorf("json.Marshal: %w", err) + } + var id int64 + if err = stmt.QueryRowContext(ctx, int64(roomNID), value).Scan(&id); err != nil { + return 0, fmt.Errorf("stmt.ExecContext: %w", err) + } + return types.StateSnapshotNID(id), nil +} + +func (s *stateStatements) BulkSelectState( + ctx context.Context, stateNIDs []types.StateSnapshotNID, +) (map[types.StateSnapshotNID][]types.EventNID, error) { + selectOrig := strings.Replace(bulkSelectNewStateSnapshotSQL, "($1)", sqlutil.QueryVariadic(len(stateNIDs)), 1) + selectStmt, err := s.db.Prepare(selectOrig) + if err != nil { + return nil, err + } + nids := make([]interface{}, len(stateNIDs)) + for i := range stateNIDs { + nids[i] = int64(stateNIDs[i]) + } + rows, err := selectStmt.QueryContext(ctx, nids...) + if err != nil { + return nil, err + } + defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed") + results := map[types.StateSnapshotNID][]types.EventNID{} + for rows.Next() { + var stateNID int64 + var eventNIDJSON json.RawMessage + if err = rows.Scan(&stateNID, &eventNIDJSON); err != nil { + return nil, fmt.Errorf("rows.Scan: %w", err) + } + var eventNIDs []int64 + if err = json.Unmarshal(eventNIDJSON, &eventNIDs); err != nil { + return nil, fmt.Errorf("json.Unmarshal: %w", err) + } + for _, id := range eventNIDs { + results[types.StateSnapshotNID(stateNID)] = append( + results[types.StateSnapshotNID(stateNID)], + types.EventNID(id), + ) + } + } + if err = rows.Err(); err != nil { + return nil, fmt.Errorf("rows.Err: %w", err) + } + if len(results) != len(stateNIDs) { + return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs)) + } + return results, err +} diff --git a/roomserver/storage/sqlite3/storage.go b/roomserver/storage/sqlite3/storage.go index 8e608a6d..ea1f4073 100644 --- a/roomserver/storage/sqlite3/storage.go +++ b/roomserver/storage/sqlite3/storage.go @@ -98,11 +98,17 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { if err != nil { return err } - stateBlock, err := NewSqliteStateBlockTable(db) - if err != nil { - return err - } - stateSnapshot, err := NewSqliteStateSnapshotTable(db) + /* + stateBlock, err := NewSqliteStateBlockTable(db) + if err != nil { + return err + } + stateSnapshot, err := NewSqliteStateSnapshotTable(db) + if err != nil { + return err + } + */ + state, err := NewPostgresStateTable(db) if err != nil { return err } @@ -131,17 +137,18 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error { return err } d.Database = shared.Database{ - DB: db, - Cache: cache, - Writer: sqlutil.NewExclusiveWriter(), - EventsTable: events, - EventTypesTable: eventTypes, - EventStateKeysTable: eventStateKeys, - EventJSONTable: eventJSON, - RoomsTable: rooms, - TransactionsTable: transactions, - StateBlockTable: stateBlock, - StateSnapshotTable: stateSnapshot, + DB: db, + Cache: cache, + Writer: sqlutil.NewExclusiveWriter(), + EventsTable: events, + EventTypesTable: eventTypes, + EventStateKeysTable: eventStateKeys, + EventJSONTable: eventJSON, + RoomsTable: rooms, + TransactionsTable: transactions, + StateTable: state, + // StateBlockTable: stateBlock, + // StateSnapshotTable: stateSnapshot, PrevEventsTable: prevEvents, RoomAliasesTable: roomAliases, InvitesTable: invites, diff --git a/roomserver/storage/tables/interface.go b/roomserver/storage/tables/interface.go index 26bf5cf0..9d5ef631 100644 --- a/roomserver/storage/tables/interface.go +++ b/roomserver/storage/tables/interface.go @@ -43,6 +43,7 @@ type Events interface { // bulkSelectStateEventByID lookups a list of state events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) + BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID) ([]types.StateEntry, error) // BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // If any of the requested events are missing from the database it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError. @@ -80,6 +81,12 @@ type Transactions interface { SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) } +type State interface { + InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID) (types.StateSnapshotNID, error) + BulkSelectState(ctx context.Context, stateNIDs []types.StateSnapshotNID) (map[types.StateSnapshotNID][]types.EventNID, error) +} + +/* type StateSnapshot interface { InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) @@ -90,6 +97,7 @@ type StateBlock interface { BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) } +*/ type RoomAliases interface { InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) diff --git a/roomserver/types/types.go b/roomserver/types/types.go index e866f6cb..e8787c29 100644 --- a/roomserver/types/types.go +++ b/roomserver/types/types.go @@ -74,6 +74,24 @@ func (a StateEntry) LessThan(b StateEntry) bool { return a.EventNID < b.EventNID } +// Deduplicate takes a set of event NIDs and ensures that there are no +// duplicates. If there are then we dedupe them. +func DeduplicateEventNIDs(a []EventNID) []EventNID { + if len(a) < 2 { + return a + } + sort.SliceStable(a, func(i, j int) bool { + return a[i] < a[j] + }) + for i := 0; i < len(a)-1; i++ { + if a[i] == a[i+1] { + a = append(a[:i], a[i+1:]...) + i-- + } + } + return a +} + // Deduplicate takes a set of state entries and ensures that there are no // duplicate (event type, state key) tuples. If there are then we dedupe // them, making sure that the latest/highest NIDs are always chosen. @@ -151,18 +169,6 @@ const ( EmptyStateKeyNID = 1 ) -// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database. -type StateBlockNIDList struct { - StateSnapshotNID StateSnapshotNID - StateBlockNIDs []StateBlockNID -} - -// StateEntryList is used to return the result of bulk state entry lookups from the database. -type StateEntryList struct { - StateBlockNID StateBlockNID - StateEntries []StateEntry -} - // A MissingEventError is an error that happened because the roomserver was // missing requested events from its database. type MissingEventError string