mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Initial work on simplified state storage
This commit is contained in:
parent
e08942fb00
commit
a799847070
21 changed files with 540 additions and 1297 deletions
2
go.mod
2
go.mod
|
@ -25,7 +25,7 @@ require (
|
||||||
github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c
|
github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c
|
||||||
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
|
||||||
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
|
||||||
github.com/mattn/go-sqlite3 v1.14.6
|
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
|
||||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
|
||||||
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
|
||||||
github.com/opentracing/opentracing-go v1.2.0
|
github.com/opentracing/opentracing-go v1.2.0
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -684,8 +684,8 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME
|
||||||
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
|
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
|
||||||
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||||
github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
|
github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
|
||||||
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg=
|
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb h1:ax2vG2unlxsjwS7PMRo4FECIfAdQLowd6ejWYwPQhBo=
|
||||||
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||||
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
|
||||||
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
|
||||||
|
|
|
@ -271,7 +271,7 @@ func (r *Inputer) calculateAndSetState(
|
||||||
}
|
}
|
||||||
entries = types.DeduplicateStateEntries(entries)
|
entries = types.DeduplicateStateEntries(entries)
|
||||||
|
|
||||||
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil {
|
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID /*nil,*/, entries); err != nil {
|
||||||
return fmt.Errorf("r.DB.AddState: %w", err)
|
return fmt.Errorf("r.DB.AddState: %w", err)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -146,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
|
||||||
}
|
}
|
||||||
|
|
||||||
var beforeStateSnapshotNID types.StateSnapshotNID
|
var beforeStateSnapshotNID types.StateSnapshotNID
|
||||||
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil {
|
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID /*nil,*/, entries); err != nil {
|
||||||
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,6 +25,7 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -50,32 +51,10 @@ func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResol
|
||||||
func (v *StateResolution) LoadStateAtSnapshot(
|
func (v *StateResolution) LoadStateAtSnapshot(
|
||||||
ctx context.Context, stateNID types.StateSnapshotNID,
|
ctx context.Context, stateNID types.StateSnapshotNID,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
fullState, err := v.db.StateEntries(ctx, stateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("v.db.StateEntries: %w", err)
|
||||||
}
|
}
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
|
||||||
|
|
||||||
stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
|
||||||
|
|
||||||
// Combine all the state entries for this snapshot.
|
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
|
||||||
var fullState []types.StateEntry
|
|
||||||
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
|
||||||
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
|
||||||
if !ok {
|
|
||||||
// This should only get hit if the database is corrupt.
|
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
|
||||||
}
|
|
||||||
fullState = append(fullState, entries...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stable sort so that the most recent entry for each state key stays
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
// remains later in the list than the older entries for the same state key.
|
// remains later in the list than the older entries for the same state key.
|
||||||
sort.Stable(stateEntryByStateKeySorter(fullState))
|
sort.Stable(stateEntryByStateKeySorter(fullState))
|
||||||
|
@ -95,12 +74,10 @@ func (v *StateResolution) LoadStateAtEvent(
|
||||||
if snapshotNID == 0 {
|
if snapshotNID == 0 {
|
||||||
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
|
||||||
}
|
}
|
||||||
|
|
||||||
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return stateEntries, nil
|
return stateEntries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -114,53 +91,33 @@ func (v *StateResolution) LoadCombinedStateAfterEvents(
|
||||||
for i, state := range prevStates {
|
for i, state := range prevStates {
|
||||||
stateNIDs[i] = state.BeforeStateSnapshotNID
|
stateNIDs[i] = state.BeforeStateSnapshotNID
|
||||||
}
|
}
|
||||||
// Fetch the state snapshots for the state before the each prev event from the database.
|
|
||||||
// Deduplicate the IDs before passing them to the database.
|
|
||||||
// There could be duplicates because the events could be state events where
|
|
||||||
// the snapshot of the room state before them was the same.
|
|
||||||
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, UniqueStateSnapshotNIDs(stateNIDs))
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var stateBlockNIDs []types.StateBlockNID
|
|
||||||
for _, list := range stateBlockNIDLists {
|
|
||||||
stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...)
|
|
||||||
}
|
|
||||||
// Fetch the state entries that will be combined to create the snapshots.
|
// Fetch the state entries that will be combined to create the snapshots.
|
||||||
// Deduplicate the IDs before passing them to the database.
|
// Deduplicate the IDs before passing them to the database.
|
||||||
// There could be duplicates because a block of state entries could be reused by
|
// There could be duplicates because a block of state entries could be reused by
|
||||||
// multiple snapshots.
|
// multiple snapshots.
|
||||||
stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs))
|
stateEntriesMap := map[types.StateSnapshotNID][]types.StateEntry{}
|
||||||
|
for _, nid := range stateNIDs {
|
||||||
|
entries, err := v.db.StateEntries(ctx, nid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("v.db.StateEntries: %w", err)
|
return nil, fmt.Errorf("v.db.StateEntries: %w", err)
|
||||||
}
|
}
|
||||||
stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
|
stateEntriesMap[nid] = entries
|
||||||
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
}
|
||||||
|
|
||||||
// Combine the entries from all the snapshots of state after each prev event into a single list.
|
// Combine the entries from all the snapshots of state after each prev event into a single list.
|
||||||
var combined []types.StateEntry
|
var combined []types.StateEntry
|
||||||
for _, prevState := range prevStates {
|
for _, prevState := range prevStates {
|
||||||
// Grab the list of state data NIDs for this snapshot.
|
|
||||||
stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
|
|
||||||
if !ok {
|
|
||||||
// This should only get hit if the database is corrupt.
|
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine all the state entries for this snapshot.
|
// Combine all the state entries for this snapshot.
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
// The order of state block NIDs in the list tells us the order to combine them in.
|
||||||
var fullState []types.StateEntry
|
var fullState []types.StateEntry
|
||||||
for _, stateBlockNID := range stateBlockNIDs {
|
entries, ok := stateEntriesMap[prevState.BeforeStateSnapshotNID]
|
||||||
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
|
||||||
if !ok {
|
if !ok {
|
||||||
// This should only get hit if the database is corrupt.
|
// This should only get hit if the database is corrupt.
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
// It should be impossible for an event to reference a NID that doesn't exist
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
panic(fmt.Errorf("Corrupt DB: Missing state snapshot %d", prevState.BeforeStateSnapshotNID))
|
||||||
}
|
}
|
||||||
fullState = append(fullState, entries...)
|
fullState = append(fullState, entries...)
|
||||||
}
|
|
||||||
if prevState.IsStateEvent() && !prevState.IsRejected {
|
if prevState.IsStateEvent() && !prevState.IsRejected {
|
||||||
// If the prev event was a state event then add an entry for the event itself
|
// If the prev event was a state event then add an entry for the event itself
|
||||||
// so that we get the state after the event rather than the state before.
|
// so that we get the state after the event rather than the state before.
|
||||||
|
@ -192,13 +149,13 @@ func (v *StateResolution) DifferenceBetweeenStateSnapshots(
|
||||||
if oldStateNID != 0 {
|
if oldStateNID != 0 {
|
||||||
oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID)
|
oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if newStateNID != 0 {
|
if newStateNID != 0 {
|
||||||
newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID)
|
newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -299,33 +256,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
|
||||||
stateNID types.StateSnapshotNID,
|
stateNID types.StateSnapshotNID,
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
stateKeyTuples []types.StateKeyTuple,
|
||||||
) ([]types.StateEntry, error) {
|
) ([]types.StateEntry, error) {
|
||||||
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
fullState, err := v.db.StateEntries(ctx, stateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("v.db.StateEntries: %w", err)
|
||||||
}
|
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
|
||||||
|
|
||||||
stateEntryLists, err := v.db.StateEntriesForTuples(
|
|
||||||
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
|
||||||
|
|
||||||
// Combine all the state entries for this snapshot.
|
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
|
||||||
var fullState []types.StateEntry
|
|
||||||
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
|
||||||
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
|
||||||
if !ok {
|
|
||||||
// If the block is missing from the map it means that none of its entries matched a requested tuple.
|
|
||||||
// This can happen if the block doesn't contain an update for one of the requested tuples.
|
|
||||||
// If none of the requested tuples are in the block then it can be safely skipped.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
fullState = append(fullState, entries...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stable sort so that the most recent entry for each state key stays
|
// Stable sort so that the most recent entry for each state key stays
|
||||||
|
@ -549,7 +482,8 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
// 2) There weren't any prev_events for this event so the state is
|
// 2) There weren't any prev_events for this event so the state is
|
||||||
// empty.
|
// empty.
|
||||||
metrics.algorithm = "empty_state"
|
metrics.algorithm = "empty_state"
|
||||||
stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil)
|
stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID /*nil,*/, nil)
|
||||||
|
logrus.Warnf("Empty prev state added state snapshot %d", stateNID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("v.db.AddState: %w", err)
|
err = fmt.Errorf("v.db.AddState: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -566,6 +500,33 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
metrics.algorithm = "no_change"
|
metrics.algorithm = "no_change"
|
||||||
return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
|
return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldState, err := v.db.StateEntries(
|
||||||
|
ctx, prevState.BeforeStateSnapshotNID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("v.db.StateEntries: %w", err)
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, s := range oldState {
|
||||||
|
if s.EventNID == prevState.StateEntry.EventNID {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
oldState = append(oldState, prevState.StateEntry)
|
||||||
|
}
|
||||||
|
|
||||||
|
stateNID, err := v.db.AddState(
|
||||||
|
ctx, v.roomInfo.RoomNID, oldState,
|
||||||
|
)
|
||||||
|
logrus.Warnf("Single prev state added state snapshot %d", stateNID)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("v.db.AddState: %w", err)
|
||||||
|
}
|
||||||
|
return stateNID, nil
|
||||||
|
/*
|
||||||
// The previous event was a state event so we need to store a copy
|
// The previous event was a state event so we need to store a copy
|
||||||
// of the previous state updated with that event.
|
// of the previous state updated with that event.
|
||||||
stateBlockNIDLists, err := v.db.StateBlockNIDs(
|
stateBlockNIDLists, err := v.db.StateBlockNIDs(
|
||||||
|
@ -588,6 +549,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
}
|
}
|
||||||
return metrics.stop(stateNID, err)
|
return metrics.stop(stateNID, err)
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
// If there are too many deltas then we need to calculate the full state
|
// If there are too many deltas then we need to calculate the full state
|
||||||
// So fall through to calculateAndStoreStateAfterManyEvents
|
// So fall through to calculateAndStoreStateAfterManyEvents
|
||||||
}
|
}
|
||||||
|
@ -596,6 +558,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
|
||||||
}
|
}
|
||||||
|
logrus.Warnf("Multiple prev states added state snapshot %d", stateNID)
|
||||||
return stateNID, nil
|
return stateNID, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -626,7 +589,7 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
|
||||||
// previous state.
|
// previous state.
|
||||||
metrics.conflictLength = conflictLength
|
metrics.conflictLength = conflictLength
|
||||||
metrics.fullStateLength = len(state)
|
metrics.fullStateLength = len(state)
|
||||||
return metrics.stop(v.db.AddState(ctx, roomNID, nil, state))
|
return metrics.stop(v.db.AddState(ctx, roomNID /*nil,*/, state))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (v *StateResolution) calculateStateAfterManyEvents(
|
func (v *StateResolution) calculateStateAfterManyEvents(
|
||||||
|
@ -996,34 +959,6 @@ func (s stateEntrySorter) Len() int { return len(s) }
|
||||||
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||||
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||||
|
|
||||||
type stateBlockNIDListMap []types.StateBlockNIDList
|
|
||||||
|
|
||||||
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
|
|
||||||
list := []types.StateBlockNIDList(m)
|
|
||||||
i := sort.Search(len(list), func(i int) bool {
|
|
||||||
return list[i].StateSnapshotNID >= stateNID
|
|
||||||
})
|
|
||||||
if i < len(list) && list[i].StateSnapshotNID == stateNID {
|
|
||||||
ok = true
|
|
||||||
stateBlockNIDs = list[i].StateBlockNIDs
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateEntryListMap []types.StateEntryList
|
|
||||||
|
|
||||||
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
|
|
||||||
list := []types.StateEntryList(m)
|
|
||||||
i := sort.Search(len(list), func(i int) bool {
|
|
||||||
return list[i].StateBlockNID >= stateBlockNID
|
|
||||||
})
|
|
||||||
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
|
|
||||||
ok = true
|
|
||||||
stateEntries = list[i].StateEntries
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateEntryByStateKeySorter []types.StateEntry
|
type stateEntryByStateKeySorter []types.StateEntry
|
||||||
|
|
||||||
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
|
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
|
||||||
|
|
|
@ -33,7 +33,6 @@ type Database interface {
|
||||||
AddState(
|
AddState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (types.StateSnapshotNID, error)
|
) (types.StateSnapshotNID, error)
|
||||||
// Look up the state of a room at each event for a list of string event IDs.
|
// Look up the state of a room at each event for a list of string event IDs.
|
||||||
|
@ -47,21 +46,9 @@ type Database interface {
|
||||||
// Look up the numeric IDs for a list of string event state keys.
|
// Look up the numeric IDs for a list of string event state keys.
|
||||||
// Returns a map from string state key to numeric ID for the state key.
|
// Returns a map from string state key to numeric ID for the state key.
|
||||||
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
|
||||||
// Look up the numeric state data IDs for each numeric state snapshot ID
|
|
||||||
// The returned slice is sorted by numeric state snapshot ID.
|
|
||||||
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
|
||||||
// Look up the state data for each numeric state data ID
|
// Look up the state data for each numeric state data ID
|
||||||
// The returned slice is sorted by numeric state data ID.
|
// The returned slice is sorted by numeric state data ID.
|
||||||
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
StateEntries(ctx context.Context, stateSnapshotNID types.StateSnapshotNID) ([]types.StateEntry, error)
|
||||||
// Look up the state data for the state key tuples for each numeric state block ID
|
|
||||||
// This is used to fetch a subset of the room state at a snapshot.
|
|
||||||
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
|
|
||||||
// The returned slice is sorted by numeric state block ID.
|
|
||||||
StateEntriesForTuples(
|
|
||||||
ctx context.Context,
|
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
|
||||||
) ([]types.StateEntryList, error)
|
|
||||||
// Look up the Events for a list of numeric event IDs.
|
// Look up the Events for a list of numeric event IDs.
|
||||||
// Returns a sorted list of events.
|
// Returns a sorted list of events.
|
||||||
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
|
||||||
|
|
|
@ -88,6 +88,12 @@ const bulkSelectStateEventByIDSQL = "" +
|
||||||
" WHERE event_id = ANY($1)" +
|
" WHERE event_id = ANY($1)" +
|
||||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
|
// Bulk lookup of events by numeric ID.
|
||||||
|
const bulkSelectStateEventByNIDSQL = "" +
|
||||||
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
" WHERE event_nid = ANY($1)" +
|
||||||
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
const bulkSelectStateAtEventByIDSQL = "" +
|
const bulkSelectStateAtEventByIDSQL = "" +
|
||||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
||||||
" WHERE event_id = ANY($1)"
|
" WHERE event_id = ANY($1)"
|
||||||
|
@ -127,6 +133,7 @@ type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
|
bulkSelectStateEventByNIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
updateEventStateStmt *sql.Stmt
|
updateEventStateStmt *sql.Stmt
|
||||||
selectEventSentToOutputStmt *sql.Stmt
|
selectEventSentToOutputStmt *sql.Stmt
|
||||||
|
@ -151,6 +158,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
|
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||||
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||||
|
@ -238,6 +246,48 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
return results, nil
|
return results, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bulkSelectStateEventByNID lookups a list of state events by internal numeric ID.
|
||||||
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
|
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
|
||||||
|
// We know that we will only get as many results as event NIDs
|
||||||
|
// because of the unique constraint on event NIDs.
|
||||||
|
// So we can allocate an array of the correct size now.
|
||||||
|
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
|
||||||
|
results := make([]types.StateEntry, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
result := &results[i]
|
||||||
|
if err = rows.Scan(
|
||||||
|
&result.EventTypeNID,
|
||||||
|
&result.EventStateKeyNID,
|
||||||
|
&result.EventNID,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if i != len(eventNIDs) {
|
||||||
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
|
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||||
|
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||||
|
// to do it as a separate query rather than slowing down/complicating the internal case.
|
||||||
|
return nil, types.MissingEventError(
|
||||||
|
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return results, nil
|
||||||
|
}
|
||||||
|
|
||||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
|
|
|
@ -1,292 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
const stateDataSchema = `
|
|
||||||
-- The state data map.
|
|
||||||
-- Designed to give enough information to run the state resolution algorithm
|
|
||||||
-- without hitting the database in the internal case.
|
|
||||||
-- TODO: Is it worth replacing the unique btree index with a covering index so
|
|
||||||
-- that postgres could lookup the state using an index-only scan?
|
|
||||||
-- The type and state_key are included in the index to make it easier to
|
|
||||||
-- lookup a specific (type, state_key) pair for an event. It also makes it easy
|
|
||||||
-- to read the state for a given state_block_nid ordered by (type, state_key)
|
|
||||||
-- which in turn makes it easier to merge state data blocks.
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq;
|
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
|
||||||
-- Local numeric ID for this state data.
|
|
||||||
state_block_nid bigint NOT NULL,
|
|
||||||
event_type_nid bigint NOT NULL,
|
|
||||||
event_state_key_nid bigint NOT NULL,
|
|
||||||
event_nid bigint NOT NULL,
|
|
||||||
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
|
||||||
);
|
|
||||||
`
|
|
||||||
|
|
||||||
const insertStateDataSQL = "" +
|
|
||||||
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
|
||||||
" VALUES ($1, $2, $3, $4)"
|
|
||||||
|
|
||||||
const selectNextStateBlockNIDSQL = "" +
|
|
||||||
"SELECT nextval('roomserver_state_block_nid_seq')"
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
|
||||||
// This means that all the entries for a given state_block_nid will appear
|
|
||||||
// together in the list and those entries will sorted by event_type_nid
|
|
||||||
// and event_state_key_nid. This property makes it easier to merge two
|
|
||||||
// state data blocks together.
|
|
||||||
const bulkSelectStateBlockEntriesSQL = "" +
|
|
||||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
" FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
|
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Filters the rows in each block to the requested types and state keys.
|
|
||||||
// We would like to restrict to particular type state key pairs but we are
|
|
||||||
// restricted by the query language to pull the cross product of a list
|
|
||||||
// of types and a list state_keys. So we have to filter the result in the
|
|
||||||
// application to restrict it to the list of event types and state keys we
|
|
||||||
// actually wanted.
|
|
||||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
|
||||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
" FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
|
|
||||||
" AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" +
|
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
|
||||||
insertStateDataStmt *sql.Stmt
|
|
||||||
selectNextStateBlockNIDStmt *sql.Stmt
|
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
|
||||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
|
||||||
s := &stateBlockStatements{}
|
|
||||||
_, err := db.Exec(stateDataSchema)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, shared.StatementList{
|
|
||||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
|
||||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
|
||||||
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
|
||||||
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
|
||||||
}.Prepare(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkInsertStateData(
|
|
||||||
ctx context.Context,
|
|
||||||
txn *sql.Tx,
|
|
||||||
entries []types.StateEntry,
|
|
||||||
) (types.StateBlockNID, error) {
|
|
||||||
stateBlockNID, err := s.selectNextStateBlockNID(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, entry := range entries {
|
|
||||||
_, err := s.insertStateDataStmt.ExecContext(
|
|
||||||
ctx,
|
|
||||||
int64(stateBlockNID),
|
|
||||||
int64(entry.EventTypeNID),
|
|
||||||
int64(entry.EventStateKeyNID),
|
|
||||||
int64(entry.EventNID),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return stateBlockNID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) selectNextStateBlockNID(
|
|
||||||
ctx context.Context,
|
|
||||||
) (types.StateBlockNID, error) {
|
|
||||||
var stateBlockNID int64
|
|
||||||
err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
|
|
||||||
return types.StateBlockNID(stateBlockNID), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
|
||||||
for i := range stateBlockNIDs {
|
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
|
||||||
|
|
||||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
|
||||||
// current is a pointer to the StateEntryList to append the state entries to.
|
|
||||||
var current *types.StateEntryList
|
|
||||||
i := 0
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
stateBlockNID int64
|
|
||||||
eventTypeNID int64
|
|
||||||
eventStateKeyNID int64
|
|
||||||
eventNID int64
|
|
||||||
entry types.StateEntry
|
|
||||||
)
|
|
||||||
if err = rows.Scan(
|
|
||||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
|
||||||
entry.EventNID = types.EventNID(eventNID)
|
|
||||||
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
|
||||||
// The state entry row is for a different state data block to the current one.
|
|
||||||
// So we start appending to the next entry in the list.
|
|
||||||
current = &results[i]
|
|
||||||
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
|
||||||
}
|
|
||||||
if err = rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i != len(stateBlockNIDs) {
|
|
||||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
|
|
||||||
}
|
|
||||||
return results, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
|
||||||
ctx context.Context,
|
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
|
||||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
|
||||||
sort.Sort(tuples)
|
|
||||||
|
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
|
||||||
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
|
|
||||||
ctx,
|
|
||||||
stateBlockNIDsAsArray(stateBlockNIDs),
|
|
||||||
eventTypeNIDArray,
|
|
||||||
eventStateKeyNIDArray,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
|
|
||||||
|
|
||||||
var results []types.StateEntryList
|
|
||||||
var current types.StateEntryList
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
stateBlockNID int64
|
|
||||||
eventTypeNID int64
|
|
||||||
eventStateKeyNID int64
|
|
||||||
eventNID int64
|
|
||||||
entry types.StateEntry
|
|
||||||
)
|
|
||||||
if err := rows.Scan(
|
|
||||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
|
||||||
entry.EventNID = types.EventNID(eventNID)
|
|
||||||
|
|
||||||
// We can use binary search here because we sorted the tuples earlier
|
|
||||||
if !tuples.contains(entry.StateKeyTuple) {
|
|
||||||
// The select will return the cross product of types and state keys.
|
|
||||||
// So we need to check if type of the entry is in the list.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
|
||||||
// The state entry row is for a different state data block to the current one.
|
|
||||||
// So we append the current entry to the results and start adding to a new one.
|
|
||||||
// The first time through the loop current will be empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
|
||||||
}
|
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
|
||||||
}
|
|
||||||
// Add the last entry to the list if it is not empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
return results, rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
|
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
|
||||||
for i := range stateBlockNIDs {
|
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
|
||||||
}
|
|
||||||
return pq.Int64Array(nids)
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateKeyTupleSorter []types.StateKeyTuple
|
|
||||||
|
|
||||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
|
||||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
|
||||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
||||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
|
||||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
|
||||||
return i < len(s) && s[i] == value
|
|
||||||
}
|
|
||||||
|
|
||||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
|
||||||
// Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
|
|
||||||
eventTypeNIDs = make(pq.Int64Array, len(s))
|
|
||||||
eventStateKeyNIDs = make(pq.Int64Array, len(s))
|
|
||||||
for i := range s {
|
|
||||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
|
||||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
|
||||||
}
|
|
||||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
|
||||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type int64Sorter []int64
|
|
||||||
|
|
||||||
func (s int64Sorter) Len() int { return len(s) }
|
|
||||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
|
||||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
|
@ -1,86 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStateKeyTupleSorter(t *testing.T) {
|
|
||||||
input := stateKeyTupleSorter{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
want := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
}
|
|
||||||
doNotWant := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
wantTypeNIDs := []int64{1, 2}
|
|
||||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
|
||||||
|
|
||||||
// Sort the input and check it's in the right order.
|
|
||||||
sort.Sort(input)
|
|
||||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
|
||||||
|
|
||||||
for i := range want {
|
|
||||||
if input[i] != want[i] {
|
|
||||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !input.contains(want[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range doNotWant {
|
|
||||||
if input.contains(doNotWant[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantTypeNIDs {
|
|
||||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
|
||||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantStateKeyNIDs {
|
|
||||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,126 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package postgres
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"github.com/lib/pq"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
const stateSnapshotSchema = `
|
|
||||||
-- The state of a room before an event.
|
|
||||||
-- Stored as a list of state_block entries stored in a separate table.
|
|
||||||
-- The actual state is constructed by combining all the state_block entries
|
|
||||||
-- referenced by state_block_nids together. If the same state key tuple appears
|
|
||||||
-- multiple times then the entry from the later state_block clobbers the earlier
|
|
||||||
-- entries.
|
|
||||||
-- This encoding format allows us to implement a delta encoding which is useful
|
|
||||||
-- because room state tends to accumulate small changes over time. Although if
|
|
||||||
-- the list of deltas becomes too long it becomes more efficient to encode
|
|
||||||
-- the full state under single state_block_nid.
|
|
||||||
CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq;
|
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
|
||||||
-- Local numeric ID for the state.
|
|
||||||
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
|
|
||||||
-- Local numeric ID of the room this state is for.
|
|
||||||
-- Unused in normal operation, but useful for background work or ad-hoc debugging.
|
|
||||||
room_nid bigint NOT NULL,
|
|
||||||
-- List of state_block_nids, stored sorted by state_block_nid.
|
|
||||||
state_block_nids bigint[] NOT NULL
|
|
||||||
);
|
|
||||||
`
|
|
||||||
|
|
||||||
const insertStateSQL = "" +
|
|
||||||
"INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" +
|
|
||||||
" VALUES ($1, $2)" +
|
|
||||||
" RETURNING state_snapshot_nid"
|
|
||||||
|
|
||||||
// Bulk state data NID lookup.
|
|
||||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
|
||||||
// to lookup the state data NIDs for a state snapshot NID.
|
|
||||||
const bulkSelectStateBlockNIDsSQL = "" +
|
|
||||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
|
||||||
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
|
|
||||||
|
|
||||||
type stateSnapshotStatements struct {
|
|
||||||
insertStateStmt *sql.Stmt
|
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
|
||||||
s := &stateSnapshotStatements{}
|
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, shared.StatementList{
|
|
||||||
{&s.insertStateStmt, insertStateSQL},
|
|
||||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
|
||||||
}.Prepare(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) InsertState(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
|
||||||
nids := make([]int64, len(stateBlockNIDs))
|
|
||||||
for i := range stateBlockNIDs {
|
|
||||||
nids[i] = int64(stateBlockNIDs[i])
|
|
||||||
}
|
|
||||||
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
|
||||||
) ([]types.StateBlockNIDList, error) {
|
|
||||||
nids := make([]int64, len(stateNIDs))
|
|
||||||
for i := range stateNIDs {
|
|
||||||
nids[i] = int64(stateNIDs[i])
|
|
||||||
}
|
|
||||||
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close() // nolint: errcheck
|
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
|
||||||
i := 0
|
|
||||||
for ; rows.Next(); i++ {
|
|
||||||
result := &results[i]
|
|
||||||
var stateBlockNIDs pq.Int64Array
|
|
||||||
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
|
|
||||||
for k := range stateBlockNIDs {
|
|
||||||
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if err = rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if i != len(stateNIDs) {
|
|
||||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
124
roomserver/storage/postgres/state_table.go
Normal file
124
roomserver/storage/postgres/state_table.go
Normal file
|
@ -0,0 +1,124 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package postgres
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const stateSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS roomserver_state (
|
||||||
|
state_nid BIGSERIAL PRIMARY KEY,
|
||||||
|
room_nid bigint NOT NULL,
|
||||||
|
event_nids bigint[] NOT NULL,
|
||||||
|
UNIQUE (room_nid, event_nids),
|
||||||
|
CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertNewStateSnapshotSQL = "" +
|
||||||
|
"INSERT INTO roomserver_state (room_nid, event_nids)" +
|
||||||
|
" VALUES ($1, $2)" +
|
||||||
|
" ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" +
|
||||||
|
" RETURNING state_nid"
|
||||||
|
|
||||||
|
const bulkSelectNewStateSnapshotSQL = "" +
|
||||||
|
"SELECT state_nid, event_nids" +
|
||||||
|
" FROM roomserver_state WHERE state_nid = ANY($1)" +
|
||||||
|
" ORDER BY state_nid"
|
||||||
|
|
||||||
|
type stateStatements struct {
|
||||||
|
insertStateStmt *sql.Stmt
|
||||||
|
bulkSelectStateStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresStateTable(db *sql.DB) (tables.State, error) {
|
||||||
|
s := &stateStatements{}
|
||||||
|
_, err := db.Exec(stateSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, shared.StatementList{
|
||||||
|
{&s.insertStateStmt, insertNewStateSnapshotSQL},
|
||||||
|
{&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateStatements) InsertState(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
eventNIDs []types.EventNID,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||||
|
var id int64
|
||||||
|
var err error
|
||||||
|
eventNIDs = types.DeduplicateEventNIDs(eventNIDs)
|
||||||
|
if err = stmt.QueryRowContext(ctx, int64(roomNID), eventNIDsAsArray(eventNIDs)).Scan(&id); err != nil {
|
||||||
|
return 0, fmt.Errorf("stmt.ExecContext: %w", err)
|
||||||
|
}
|
||||||
|
return types.StateSnapshotNID(id), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateStatements) BulkSelectState(
|
||||||
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) (map[types.StateSnapshotNID][]types.EventNID, error) {
|
||||||
|
rows, err := s.bulkSelectStateStmt.QueryContext(ctx, stateSnapshotNIDsAsArray(stateNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||||
|
|
||||||
|
results := map[types.StateSnapshotNID][]types.EventNID{}
|
||||||
|
for rows.Next() {
|
||||||
|
var stateNID int64
|
||||||
|
var eventNIDs pq.Int64Array
|
||||||
|
if err = rows.Scan(&stateNID, &eventNIDs); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
for _, id := range eventNIDs {
|
||||||
|
results[types.StateSnapshotNID(stateNID)] = append(
|
||||||
|
results[types.StateSnapshotNID(stateNID)],
|
||||||
|
types.EventNID(id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("rows.Err: %w", err)
|
||||||
|
}
|
||||||
|
if len(results) != len(stateNIDs) {
|
||||||
|
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs))
|
||||||
|
}
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateSnapshotNIDsAsArray(stateSnapshotNIDs []types.StateSnapshotNID) pq.Int64Array {
|
||||||
|
nids := make([]int64, len(stateSnapshotNIDs))
|
||||||
|
for i := range stateSnapshotNIDs {
|
||||||
|
nids[i] = int64(stateSnapshotNIDs[i])
|
||||||
|
}
|
||||||
|
return nids
|
||||||
|
}
|
|
@ -86,6 +86,7 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
stateBlock, err := NewPostgresStateBlockTable(db)
|
stateBlock, err := NewPostgresStateBlockTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -94,6 +95,11 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
state, err := NewPostgresStateTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
roomAliases, err := NewPostgresRoomAliasesTable(db)
|
roomAliases, err := NewPostgresRoomAliasesTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -128,8 +134,9 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro
|
||||||
EventsTable: events,
|
EventsTable: events,
|
||||||
RoomsTable: rooms,
|
RoomsTable: rooms,
|
||||||
TransactionsTable: transactions,
|
TransactionsTable: transactions,
|
||||||
StateBlockTable: stateBlock,
|
StateTable: state,
|
||||||
StateSnapshotTable: stateSnapshot,
|
// StateBlockTable: stateBlock,
|
||||||
|
// StateSnapshotTable: stateSnapshot,
|
||||||
PrevEventsTable: prevEvents,
|
PrevEventsTable: prevEvents,
|
||||||
RoomAliasesTable: roomAliases,
|
RoomAliasesTable: roomAliases,
|
||||||
InvitesTable: invites,
|
InvitesTable: invites,
|
||||||
|
|
|
@ -5,7 +5,6 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal/caching"
|
"github.com/matrix-org/dendrite/internal/caching"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
@ -13,7 +12,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -36,8 +34,7 @@ type Database struct {
|
||||||
EventStateKeysTable tables.EventStateKeys
|
EventStateKeysTable tables.EventStateKeys
|
||||||
RoomsTable tables.Rooms
|
RoomsTable tables.Rooms
|
||||||
TransactionsTable tables.Transactions
|
TransactionsTable tables.Transactions
|
||||||
StateSnapshotTable tables.StateSnapshot
|
StateTable tables.State
|
||||||
StateBlockTable tables.StateBlock
|
|
||||||
RoomAliasesTable tables.RoomAliases
|
RoomAliasesTable tables.RoomAliases
|
||||||
PrevEventsTable tables.PreviousEvents
|
PrevEventsTable tables.PreviousEvents
|
||||||
InvitesTable tables.Invites
|
InvitesTable tables.Invites
|
||||||
|
@ -113,16 +110,6 @@ func (d *Database) StateEntriesForEventIDs(
|
||||||
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
|
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateEntriesForTuples(
|
|
||||||
ctx context.Context,
|
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
return d.StateBlockTable.BulkSelectFilteredStateBlockEntries(
|
|
||||||
ctx, stateBlockNIDs, stateKeyTuples,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
|
||||||
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
|
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
|
||||||
return &roomInfo, nil
|
return &roomInfo, nil
|
||||||
|
@ -138,21 +125,16 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo
|
||||||
func (d *Database) AddState(
|
func (d *Database) AddState(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
roomNID types.RoomNID,
|
roomNID types.RoomNID,
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
state []types.StateEntry,
|
state []types.StateEntry,
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
) (stateNID types.StateSnapshotNID, err error) {
|
||||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if len(state) > 0 {
|
eventNIDs := make([]types.EventNID, 0, len(state))
|
||||||
var stateBlockNID types.StateBlockNID
|
for _, s := range state {
|
||||||
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state)
|
eventNIDs = append(eventNIDs, s.EventNID)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err)
|
|
||||||
}
|
}
|
||||||
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
|
stateNID, err = d.StateTable.InsertState(ctx, txn, roomNID, eventNIDs)
|
||||||
}
|
|
||||||
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err)
|
return fmt.Errorf("d.StateTable.InsertState: %w", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -228,16 +210,22 @@ func (d *Database) LatestEventIDs(
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) StateBlockNIDs(
|
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
|
||||||
) ([]types.StateBlockNIDList, error) {
|
|
||||||
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Database) StateEntries(
|
func (d *Database) StateEntries(
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
ctx context.Context, stateSnapshotNID types.StateSnapshotNID,
|
||||||
) ([]types.StateEntryList, error) {
|
) ([]types.StateEntry, error) {
|
||||||
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs)
|
nids, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{stateSnapshotNID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err)
|
||||||
|
}
|
||||||
|
state, ok := nids[stateSnapshotNID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("state snapshot %d not found", stateSnapshotNID)
|
||||||
|
}
|
||||||
|
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, state)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
||||||
|
}
|
||||||
|
return entries, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
|
||||||
|
@ -817,9 +805,17 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
snapshots, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err)
|
||||||
|
}
|
||||||
|
nids, ok := snapshots[roomInfo.StateSnapshotNID]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("state snapshot %d not found", roomInfo.StateSnapshotNID)
|
||||||
|
}
|
||||||
|
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nids)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
|
||||||
}
|
}
|
||||||
var eventNIDs []types.EventNID
|
var eventNIDs []types.EventNID
|
||||||
for _, e := range entries {
|
for _, e := range entries {
|
||||||
|
@ -938,7 +934,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
|
||||||
if roomInfo == nil || roomInfo.IsStub {
|
if roomInfo == nil || roomInfo.IsStub {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID)
|
entries, err2 := d.StateEntries(ctx, roomInfo.StateSnapshotNID)
|
||||||
if err2 != nil {
|
if err2 != nil {
|
||||||
return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2)
|
return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2)
|
||||||
}
|
}
|
||||||
|
@ -1039,65 +1035,3 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
|
||||||
return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget)
|
return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
|
|
||||||
// it should live in this package!
|
|
||||||
|
|
||||||
func (d *Database) loadStateAtSnapshot(
|
|
||||||
ctx context.Context, stateNID types.StateSnapshotNID,
|
|
||||||
) ([]types.StateEntry, error) {
|
|
||||||
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
|
|
||||||
stateBlockNIDList := stateBlockNIDLists[0]
|
|
||||||
|
|
||||||
stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
stateEntriesMap := stateEntryListMap(stateEntryLists)
|
|
||||||
|
|
||||||
// Combine all the state entries for this snapshot.
|
|
||||||
// The order of state block NIDs in the list tells us the order to combine them in.
|
|
||||||
var fullState []types.StateEntry
|
|
||||||
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
|
|
||||||
entries, ok := stateEntriesMap.lookup(stateBlockNID)
|
|
||||||
if !ok {
|
|
||||||
// This should only get hit if the database is corrupt.
|
|
||||||
// It should be impossible for an event to reference a NID that doesn't exist
|
|
||||||
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
|
|
||||||
}
|
|
||||||
fullState = append(fullState, entries...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stable sort so that the most recent entry for each state key stays
|
|
||||||
// remains later in the list than the older entries for the same state key.
|
|
||||||
sort.Stable(stateEntryByStateKeySorter(fullState))
|
|
||||||
// Unique returns the last entry and hence the most recent entry for each state key.
|
|
||||||
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
|
|
||||||
return fullState, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateEntryListMap []types.StateEntryList
|
|
||||||
|
|
||||||
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
|
|
||||||
list := []types.StateEntryList(m)
|
|
||||||
i := sort.Search(len(list), func(i int) bool {
|
|
||||||
return list[i].StateBlockNID >= stateBlockNID
|
|
||||||
})
|
|
||||||
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
|
|
||||||
ok = true
|
|
||||||
stateEntries = list[i].StateEntries
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateEntryByStateKeySorter []types.StateEntry
|
|
||||||
|
|
||||||
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
|
|
||||||
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
|
|
||||||
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
|
|
||||||
}
|
|
||||||
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
|
@ -63,6 +63,11 @@ const bulkSelectStateEventByIDSQL = "" +
|
||||||
" WHERE event_id IN ($1)" +
|
" WHERE event_id IN ($1)" +
|
||||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
|
const bulkSelectStateEventByNIDSQL = "" +
|
||||||
|
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||||
|
" WHERE event_nid IN ($1)" +
|
||||||
|
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||||
|
|
||||||
const bulkSelectStateAtEventByIDSQL = "" +
|
const bulkSelectStateAtEventByIDSQL = "" +
|
||||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
||||||
" WHERE event_id IN ($1)"
|
" WHERE event_id IN ($1)"
|
||||||
|
@ -103,6 +108,7 @@ type eventStatements struct {
|
||||||
insertEventStmt *sql.Stmt
|
insertEventStmt *sql.Stmt
|
||||||
selectEventStmt *sql.Stmt
|
selectEventStmt *sql.Stmt
|
||||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||||
|
bulkSelectStateEventByNIDStmt *sql.Stmt
|
||||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||||
updateEventStateStmt *sql.Stmt
|
updateEventStateStmt *sql.Stmt
|
||||||
selectEventSentToOutputStmt *sql.Stmt
|
selectEventSentToOutputStmt *sql.Stmt
|
||||||
|
@ -128,6 +134,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
||||||
{&s.insertEventStmt, insertEventSQL},
|
{&s.insertEventStmt, insertEventSQL},
|
||||||
{&s.selectEventStmt, selectEventSQL},
|
{&s.selectEventStmt, selectEventSQL},
|
||||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||||
|
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
||||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||||
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||||
|
@ -232,6 +239,57 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
||||||
return results, err
|
return results, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
|
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||||
|
ctx context.Context, eventNIDs []types.EventNID,
|
||||||
|
) ([]types.StateEntry, error) {
|
||||||
|
///////////////
|
||||||
|
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||||
|
for k, v := range eventNIDs {
|
||||||
|
iEventNIDs[k] = v
|
||||||
|
}
|
||||||
|
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||||
|
selectStmt, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
///////////////
|
||||||
|
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
|
||||||
|
// We know that we will only get as many results as event IDs
|
||||||
|
// because of the unique constraint on event IDs.
|
||||||
|
// So we can allocate an array of the correct size now.
|
||||||
|
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||||
|
results := make([]types.StateEntry, len(eventNIDs))
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
result := &results[i]
|
||||||
|
if err = rows.Scan(
|
||||||
|
&result.EventTypeNID,
|
||||||
|
&result.EventStateKeyNID,
|
||||||
|
&result.EventNID,
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i != len(eventNIDs) {
|
||||||
|
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||||
|
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||||
|
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||||
|
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||||
|
// to do it as a separate query rather than slowing down/complicating the internal case.
|
||||||
|
return nil, types.MissingEventError(
|
||||||
|
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return results, err
|
||||||
|
}
|
||||||
|
|
||||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
|
|
|
@ -1,289 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
"github.com/matrix-org/util"
|
|
||||||
)
|
|
||||||
|
|
||||||
const stateDataSchema = `
|
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
|
||||||
state_block_nid INTEGER NOT NULL,
|
|
||||||
event_type_nid INTEGER NOT NULL,
|
|
||||||
event_state_key_nid INTEGER NOT NULL,
|
|
||||||
event_nid INTEGER NOT NULL,
|
|
||||||
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
|
||||||
);
|
|
||||||
`
|
|
||||||
|
|
||||||
const insertStateDataSQL = "" +
|
|
||||||
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
|
||||||
" VALUES ($1, $2, $3, $4)"
|
|
||||||
|
|
||||||
const selectNextStateBlockNIDSQL = `
|
|
||||||
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
|
||||||
`
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
|
||||||
// This means that all the entries for a given state_block_nid will appear
|
|
||||||
// together in the list and those entries will sorted by event_type_nid
|
|
||||||
// and event_state_key_nid. This property makes it easier to merge two
|
|
||||||
// state data blocks together.
|
|
||||||
const bulkSelectStateBlockEntriesSQL = "" +
|
|
||||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
|
|
||||||
// Bulk state lookup by numeric state block ID.
|
|
||||||
// Filters the rows in each block to the requested types and state keys.
|
|
||||||
// We would like to restrict to particular type state key pairs but we are
|
|
||||||
// restricted by the query language to pull the cross product of a list
|
|
||||||
// of types and a list state_keys. So we have to filter the result in the
|
|
||||||
// application to restrict it to the list of event types and state keys we
|
|
||||||
// actually wanted.
|
|
||||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
|
||||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
|
||||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
|
||||||
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
|
||||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
|
||||||
|
|
||||||
type stateBlockStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
insertStateDataStmt *sql.Stmt
|
|
||||||
selectNextStateBlockNIDStmt *sql.Stmt
|
|
||||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
|
||||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
|
||||||
s := &stateBlockStatements{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
_, err := db.Exec(stateDataSchema)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, shared.StatementList{
|
|
||||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
|
||||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
|
||||||
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
|
||||||
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
|
||||||
}.Prepare(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkInsertStateData(
|
|
||||||
ctx context.Context, txn *sql.Tx,
|
|
||||||
entries []types.StateEntry,
|
|
||||||
) (types.StateBlockNID, error) {
|
|
||||||
if len(entries) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
var stateBlockNID types.StateBlockNID
|
|
||||||
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
for _, entry := range entries {
|
|
||||||
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
|
||||||
ctx,
|
|
||||||
int64(stateBlockNID),
|
|
||||||
int64(entry.EventTypeNID),
|
|
||||||
int64(entry.EventStateKeyNID),
|
|
||||||
int64(entry.EventNID),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return stateBlockNID, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
|
||||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
nids := make([]interface{}, len(stateBlockNIDs))
|
|
||||||
for k, v := range stateBlockNIDs {
|
|
||||||
nids[k] = v
|
|
||||||
}
|
|
||||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
|
||||||
selectStmt, err := s.db.Prepare(selectOrig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
|
||||||
|
|
||||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
|
||||||
// current is a pointer to the StateEntryList to append the state entries to.
|
|
||||||
var current *types.StateEntryList
|
|
||||||
i := 0
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
stateBlockNID int64
|
|
||||||
eventTypeNID int64
|
|
||||||
eventStateKeyNID int64
|
|
||||||
eventNID int64
|
|
||||||
entry types.StateEntry
|
|
||||||
)
|
|
||||||
if err := rows.Scan(
|
|
||||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
|
||||||
entry.EventNID = types.EventNID(eventNID)
|
|
||||||
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
|
||||||
// The state entry row is for a different state data block to the current one.
|
|
||||||
// So we start appending to the next entry in the list.
|
|
||||||
current = &results[i]
|
|
||||||
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
|
||||||
}
|
|
||||||
if i != len(nids) {
|
|
||||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
|
||||||
ctx context.Context,
|
|
||||||
stateBlockNIDs []types.StateBlockNID,
|
|
||||||
stateKeyTuples []types.StateKeyTuple,
|
|
||||||
) ([]types.StateEntryList, error) {
|
|
||||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
|
||||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
|
||||||
sort.Sort(tuples)
|
|
||||||
|
|
||||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
|
||||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
|
||||||
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
|
||||||
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
|
||||||
|
|
||||||
var params []interface{}
|
|
||||||
for _, val := range stateBlockNIDs {
|
|
||||||
params = append(params, int64(val))
|
|
||||||
}
|
|
||||||
for _, val := range eventTypeNIDArray {
|
|
||||||
params = append(params, val)
|
|
||||||
}
|
|
||||||
for _, val := range eventStateKeyNIDArray {
|
|
||||||
params = append(params, val)
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := s.db.QueryContext(
|
|
||||||
ctx,
|
|
||||||
sqlStatement,
|
|
||||||
params...,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
|
|
||||||
|
|
||||||
var results []types.StateEntryList
|
|
||||||
var current types.StateEntryList
|
|
||||||
for rows.Next() {
|
|
||||||
var (
|
|
||||||
stateBlockNID int64
|
|
||||||
eventTypeNID int64
|
|
||||||
eventStateKeyNID int64
|
|
||||||
eventNID int64
|
|
||||||
entry types.StateEntry
|
|
||||||
)
|
|
||||||
if err := rows.Scan(
|
|
||||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
|
||||||
); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
|
||||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
|
||||||
entry.EventNID = types.EventNID(eventNID)
|
|
||||||
|
|
||||||
// We can use binary search here because we sorted the tuples earlier
|
|
||||||
if !tuples.contains(entry.StateKeyTuple) {
|
|
||||||
// The select will return the cross product of types and state keys.
|
|
||||||
// So we need to check if type of the entry is in the list.
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
|
||||||
// The state entry row is for a different state data block to the current one.
|
|
||||||
// So we append the current entry to the results and start adding to a new one.
|
|
||||||
// The first time through the loop current will be empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
|
||||||
}
|
|
||||||
current.StateEntries = append(current.StateEntries, entry)
|
|
||||||
}
|
|
||||||
// Add the last entry to the list if it is not empty.
|
|
||||||
if current.StateEntries != nil {
|
|
||||||
results = append(results, current)
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type stateKeyTupleSorter []types.StateKeyTuple
|
|
||||||
|
|
||||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
|
||||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
|
||||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
||||||
|
|
||||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
|
||||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
|
||||||
return i < len(s) && s[i] == value
|
|
||||||
}
|
|
||||||
|
|
||||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
|
||||||
// Assumes that the list is sorted.
|
|
||||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
|
||||||
eventTypeNIDs = make([]int64, len(s))
|
|
||||||
eventStateKeyNIDs = make([]int64, len(s))
|
|
||||||
for i := range s {
|
|
||||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
|
||||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
|
||||||
}
|
|
||||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
|
||||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
type int64Sorter []int64
|
|
||||||
|
|
||||||
func (s int64Sorter) Len() int { return len(s) }
|
|
||||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
|
||||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
|
@ -1,86 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sort"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestStateKeyTupleSorter(t *testing.T) {
|
|
||||||
input := stateKeyTupleSorter{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
want := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
|
||||||
}
|
|
||||||
doNotWant := []types.StateKeyTuple{
|
|
||||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
|
||||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
|
||||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
|
||||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
|
||||||
}
|
|
||||||
wantTypeNIDs := []int64{1, 2}
|
|
||||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
|
||||||
|
|
||||||
// Sort the input and check it's in the right order.
|
|
||||||
sort.Sort(input)
|
|
||||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
|
||||||
|
|
||||||
for i := range want {
|
|
||||||
if input[i] != want[i] {
|
|
||||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
if !input.contains(want[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range doNotWant {
|
|
||||||
if input.contains(doNotWant[i]) {
|
|
||||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantTypeNIDs {
|
|
||||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
|
||||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range wantStateKeyNIDs {
|
|
||||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
|
||||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,126 +0,0 @@
|
||||||
// Copyright 2017-2018 New Vector Ltd
|
|
||||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package sqlite3
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
|
||||||
)
|
|
||||||
|
|
||||||
const stateSnapshotSchema = `
|
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
|
||||||
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
room_nid INTEGER NOT NULL,
|
|
||||||
state_block_nids TEXT NOT NULL DEFAULT '[]'
|
|
||||||
);
|
|
||||||
`
|
|
||||||
|
|
||||||
const insertStateSQL = `
|
|
||||||
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
|
||||||
VALUES ($1, $2);`
|
|
||||||
|
|
||||||
// Bulk state data NID lookup.
|
|
||||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
|
||||||
// to lookup the state data NIDs for a state snapshot NID.
|
|
||||||
const bulkSelectStateBlockNIDsSQL = "" +
|
|
||||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
|
||||||
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
|
||||||
|
|
||||||
type stateSnapshotStatements struct {
|
|
||||||
db *sql.DB
|
|
||||||
insertStateStmt *sql.Stmt
|
|
||||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
|
||||||
s := &stateSnapshotStatements{
|
|
||||||
db: db,
|
|
||||||
}
|
|
||||||
_, err := db.Exec(stateSnapshotSchema)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return s, shared.StatementList{
|
|
||||||
{&s.insertStateStmt, insertStateSQL},
|
|
||||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
|
||||||
}.Prepare(db)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) InsertState(
|
|
||||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
|
||||||
) (stateNID types.StateSnapshotNID, err error) {
|
|
||||||
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
|
||||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
lastRowID, err := res.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
stateNID = types.StateSnapshotNID(lastRowID)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
|
||||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
|
||||||
) ([]types.StateBlockNIDList, error) {
|
|
||||||
nids := make([]interface{}, len(stateNIDs))
|
|
||||||
for k, v := range stateNIDs {
|
|
||||||
nids[k] = v
|
|
||||||
}
|
|
||||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
|
||||||
selectStmt, err := s.db.Prepare(selectOrig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
|
||||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
|
||||||
i := 0
|
|
||||||
for ; rows.Next(); i++ {
|
|
||||||
result := &results[i]
|
|
||||||
var stateBlockNIDsJSON string
|
|
||||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if i != len(stateNIDs) {
|
|
||||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
|
||||||
}
|
|
||||||
return results, nil
|
|
||||||
}
|
|
132
roomserver/storage/sqlite3/state_table.go
Normal file
132
roomserver/storage/sqlite3/state_table.go
Normal file
|
@ -0,0 +1,132 @@
|
||||||
|
// Copyright 2017-2018 New Vector Ltd
|
||||||
|
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
package sqlite3
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal"
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const stateSchema = `
|
||||||
|
CREATE TABLE IF NOT EXISTS roomserver_state (
|
||||||
|
state_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
room_nid INTEGER NOT NULL,
|
||||||
|
event_nids TEXT NOT NULL,
|
||||||
|
UNIQUE (room_nid, event_nids),
|
||||||
|
CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid)
|
||||||
|
);
|
||||||
|
`
|
||||||
|
|
||||||
|
const insertNewStateSnapshotSQL = "" +
|
||||||
|
"INSERT INTO roomserver_state (room_nid, event_nids)" +
|
||||||
|
" VALUES ($1, $2)" +
|
||||||
|
" ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" +
|
||||||
|
" RETURNING state_nid"
|
||||||
|
|
||||||
|
const bulkSelectNewStateSnapshotSQL = "" +
|
||||||
|
"SELECT state_nid, event_nids" +
|
||||||
|
" FROM roomserver_state WHERE state_nid IN ($1)" +
|
||||||
|
" ORDER BY state_nid"
|
||||||
|
|
||||||
|
type stateStatements struct {
|
||||||
|
db *sql.DB
|
||||||
|
insertStateStmt *sql.Stmt
|
||||||
|
bulkSelectStateStmt *sql.Stmt
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPostgresStateTable(db *sql.DB) (tables.State, error) {
|
||||||
|
s := &stateStatements{db: db}
|
||||||
|
_, err := db.Exec(stateSchema)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, shared.StatementList{
|
||||||
|
{&s.insertStateStmt, insertNewStateSnapshotSQL},
|
||||||
|
{&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL},
|
||||||
|
}.Prepare(db)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateStatements) InsertState(
|
||||||
|
ctx context.Context,
|
||||||
|
txn *sql.Tx,
|
||||||
|
roomNID types.RoomNID,
|
||||||
|
eventNIDs []types.EventNID,
|
||||||
|
) (types.StateSnapshotNID, error) {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||||
|
value, err := json.Marshal(types.DeduplicateEventNIDs(eventNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||||
|
}
|
||||||
|
var id int64
|
||||||
|
if err = stmt.QueryRowContext(ctx, int64(roomNID), value).Scan(&id); err != nil {
|
||||||
|
return 0, fmt.Errorf("stmt.ExecContext: %w", err)
|
||||||
|
}
|
||||||
|
return types.StateSnapshotNID(id), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *stateStatements) BulkSelectState(
|
||||||
|
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||||
|
) (map[types.StateSnapshotNID][]types.EventNID, error) {
|
||||||
|
selectOrig := strings.Replace(bulkSelectNewStateSnapshotSQL, "($1)", sqlutil.QueryVariadic(len(stateNIDs)), 1)
|
||||||
|
selectStmt, err := s.db.Prepare(selectOrig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
nids := make([]interface{}, len(stateNIDs))
|
||||||
|
for i := range stateNIDs {
|
||||||
|
nids[i] = int64(stateNIDs[i])
|
||||||
|
}
|
||||||
|
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||||
|
results := map[types.StateSnapshotNID][]types.EventNID{}
|
||||||
|
for rows.Next() {
|
||||||
|
var stateNID int64
|
||||||
|
var eventNIDJSON json.RawMessage
|
||||||
|
if err = rows.Scan(&stateNID, &eventNIDJSON); err != nil {
|
||||||
|
return nil, fmt.Errorf("rows.Scan: %w", err)
|
||||||
|
}
|
||||||
|
var eventNIDs []int64
|
||||||
|
if err = json.Unmarshal(eventNIDJSON, &eventNIDs); err != nil {
|
||||||
|
return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
||||||
|
}
|
||||||
|
for _, id := range eventNIDs {
|
||||||
|
results[types.StateSnapshotNID(stateNID)] = append(
|
||||||
|
results[types.StateSnapshotNID(stateNID)],
|
||||||
|
types.EventNID(id),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err = rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("rows.Err: %w", err)
|
||||||
|
}
|
||||||
|
if len(results) != len(stateNIDs) {
|
||||||
|
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs))
|
||||||
|
}
|
||||||
|
return results, err
|
||||||
|
}
|
|
@ -98,6 +98,7 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
/*
|
||||||
stateBlock, err := NewSqliteStateBlockTable(db)
|
stateBlock, err := NewSqliteStateBlockTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -106,6 +107,11 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
state, err := NewPostgresStateTable(db)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
prevEvents, err := NewSqlitePrevEventsTable(db)
|
prevEvents, err := NewSqlitePrevEventsTable(db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -140,8 +146,9 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
||||||
EventJSONTable: eventJSON,
|
EventJSONTable: eventJSON,
|
||||||
RoomsTable: rooms,
|
RoomsTable: rooms,
|
||||||
TransactionsTable: transactions,
|
TransactionsTable: transactions,
|
||||||
StateBlockTable: stateBlock,
|
StateTable: state,
|
||||||
StateSnapshotTable: stateSnapshot,
|
// StateBlockTable: stateBlock,
|
||||||
|
// StateSnapshotTable: stateSnapshot,
|
||||||
PrevEventsTable: prevEvents,
|
PrevEventsTable: prevEvents,
|
||||||
RoomAliasesTable: roomAliases,
|
RoomAliasesTable: roomAliases,
|
||||||
InvitesTable: invites,
|
InvitesTable: invites,
|
||||||
|
|
|
@ -43,6 +43,7 @@ type Events interface {
|
||||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||||
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
|
||||||
|
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID) ([]types.StateEntry, error)
|
||||||
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||||
|
@ -80,6 +81,12 @@ type Transactions interface {
|
||||||
SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error)
|
SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type State interface {
|
||||||
|
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID) (types.StateSnapshotNID, error)
|
||||||
|
BulkSelectState(ctx context.Context, stateNIDs []types.StateSnapshotNID) (map[types.StateSnapshotNID][]types.EventNID, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
type StateSnapshot interface {
|
type StateSnapshot interface {
|
||||||
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error)
|
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error)
|
||||||
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
|
||||||
|
@ -90,6 +97,7 @@ type StateBlock interface {
|
||||||
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
|
||||||
BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
type RoomAliases interface {
|
type RoomAliases interface {
|
||||||
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
|
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)
|
||||||
|
|
|
@ -74,6 +74,24 @@ func (a StateEntry) LessThan(b StateEntry) bool {
|
||||||
return a.EventNID < b.EventNID
|
return a.EventNID < b.EventNID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Deduplicate takes a set of event NIDs and ensures that there are no
|
||||||
|
// duplicates. If there are then we dedupe them.
|
||||||
|
func DeduplicateEventNIDs(a []EventNID) []EventNID {
|
||||||
|
if len(a) < 2 {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
sort.SliceStable(a, func(i, j int) bool {
|
||||||
|
return a[i] < a[j]
|
||||||
|
})
|
||||||
|
for i := 0; i < len(a)-1; i++ {
|
||||||
|
if a[i] == a[i+1] {
|
||||||
|
a = append(a[:i], a[i+1:]...)
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
// Deduplicate takes a set of state entries and ensures that there are no
|
// Deduplicate takes a set of state entries and ensures that there are no
|
||||||
// duplicate (event type, state key) tuples. If there are then we dedupe
|
// duplicate (event type, state key) tuples. If there are then we dedupe
|
||||||
// them, making sure that the latest/highest NIDs are always chosen.
|
// them, making sure that the latest/highest NIDs are always chosen.
|
||||||
|
@ -151,18 +169,6 @@ const (
|
||||||
EmptyStateKeyNID = 1
|
EmptyStateKeyNID = 1
|
||||||
)
|
)
|
||||||
|
|
||||||
// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database.
|
|
||||||
type StateBlockNIDList struct {
|
|
||||||
StateSnapshotNID StateSnapshotNID
|
|
||||||
StateBlockNIDs []StateBlockNID
|
|
||||||
}
|
|
||||||
|
|
||||||
// StateEntryList is used to return the result of bulk state entry lookups from the database.
|
|
||||||
type StateEntryList struct {
|
|
||||||
StateBlockNID StateBlockNID
|
|
||||||
StateEntries []StateEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
// A MissingEventError is an error that happened because the roomserver was
|
// A MissingEventError is an error that happened because the roomserver was
|
||||||
// missing requested events from its database.
|
// missing requested events from its database.
|
||||||
type MissingEventError string
|
type MissingEventError string
|
||||||
|
|
Loading…
Reference in a new issue