Initial work on simplified state storage

This commit is contained in:
Neil Alexander 2021-04-14 17:30:37 +01:00
parent e08942fb00
commit a799847070
No known key found for this signature in database
GPG key ID: A02A2019A2BB0944
21 changed files with 540 additions and 1297 deletions

2
go.mod
View file

@ -25,7 +25,7 @@ require (
github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c github.com/matrix-org/gomatrixserverlib v0.0.0-20210302161955-6142fe3f8c2c
github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161 github.com/matrix-org/naffka v0.0.0-20201009174903-d26a3b9cb161
github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4
github.com/mattn/go-sqlite3 v1.14.6 github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646
github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6 github.com/ngrok/sqlmw v0.0.0-20200129213757-d5c93a81bec6
github.com/opentracing/opentracing-go v1.2.0 github.com/opentracing/opentracing-go v1.2.0

4
go.sum
View file

@ -684,8 +684,8 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME
github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU= github.com/mattn/go-runewidth v0.0.2/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.7/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus= github.com/mattn/go-sqlite3 v1.14.2/go.mod h1:JIl7NbARA7phWnGvh0LKTyg7S9BA+6gx71ShQilpsus=
github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb h1:ax2vG2unlxsjwS7PMRo4FECIfAdQLowd6ejWYwPQhBo=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= github.com/mattn/go-sqlite3 v1.14.7-0.20210414154423-1157a4212dcb/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw= github.com/mattn/goveralls v0.0.2/go.mod h1:8d1ZMHsd7fW6IRPKQh46F2WRpyib5/X4FOpevwGNQEw=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU= github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=

View file

@ -271,7 +271,7 @@ func (r *Inputer) calculateAndSetState(
} }
entries = types.DeduplicateStateEntries(entries) entries = types.DeduplicateStateEntries(entries)
if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID, nil, entries); err != nil { if stateAtEvent.BeforeStateSnapshotNID, err = r.DB.AddState(ctx, roomInfo.RoomNID /*nil,*/, entries); err != nil {
return fmt.Errorf("r.DB.AddState: %w", err) return fmt.Errorf("r.DB.AddState: %w", err)
} }
} else { } else {

View file

@ -146,7 +146,7 @@ func (r *Backfiller) backfillViaFederation(ctx context.Context, req *api.Perform
} }
var beforeStateSnapshotNID types.StateSnapshotNID var beforeStateSnapshotNID types.StateSnapshotNID
if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID, nil, entries); err != nil { if beforeStateSnapshotNID, err = r.DB.AddState(ctx, roomNID /*nil,*/, entries); err != nil {
logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid") logrus.WithError(err).WithField("event_id", ev.EventID()).Error("backfillViaFederation: failed to persist state entries to get snapshot nid")
return err return err
} }

View file

@ -25,6 +25,7 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage" "github.com/matrix-org/dendrite/roomserver/storage"
"github.com/matrix-org/util" "github.com/matrix-org/util"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/sirupsen/logrus"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -50,32 +51,10 @@ func NewStateResolution(db storage.Database, roomInfo types.RoomInfo) StateResol
func (v *StateResolution) LoadStateAtSnapshot( func (v *StateResolution) LoadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID, ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) fullState, err := v.db.StateEntries(ctx, stateNID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("v.db.StateEntries: %w", err)
} }
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := v.db.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
}
fullState = append(fullState, entries...)
}
// Stable sort so that the most recent entry for each state key stays // Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key. // remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState)) sort.Stable(stateEntryByStateKeySorter(fullState))
@ -95,12 +74,10 @@ func (v *StateResolution) LoadStateAtEvent(
if snapshotNID == 0 { if snapshotNID == 0 {
return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID) return nil, fmt.Errorf("LoadStateAtEvent.SnapshotNIDFromEventID(%s) returned 0 NID, was this event stored?", eventID)
} }
stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID) stateEntries, err := v.LoadStateAtSnapshot(ctx, snapshotNID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
} }
return stateEntries, nil return stateEntries, nil
} }
@ -114,53 +91,33 @@ func (v *StateResolution) LoadCombinedStateAfterEvents(
for i, state := range prevStates { for i, state := range prevStates {
stateNIDs[i] = state.BeforeStateSnapshotNID stateNIDs[i] = state.BeforeStateSnapshotNID
} }
// Fetch the state snapshots for the state before the each prev event from the database.
// Deduplicate the IDs before passing them to the database.
// There could be duplicates because the events could be state events where
// the snapshot of the room state before them was the same.
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, UniqueStateSnapshotNIDs(stateNIDs))
if err != nil {
return nil, fmt.Errorf("v.db.StateBlockNIDs: %w", err)
}
var stateBlockNIDs []types.StateBlockNID
for _, list := range stateBlockNIDLists {
stateBlockNIDs = append(stateBlockNIDs, list.StateBlockNIDs...)
}
// Fetch the state entries that will be combined to create the snapshots. // Fetch the state entries that will be combined to create the snapshots.
// Deduplicate the IDs before passing them to the database. // Deduplicate the IDs before passing them to the database.
// There could be duplicates because a block of state entries could be reused by // There could be duplicates because a block of state entries could be reused by
// multiple snapshots. // multiple snapshots.
stateEntryLists, err := v.db.StateEntries(ctx, uniqueStateBlockNIDs(stateBlockNIDs)) stateEntriesMap := map[types.StateSnapshotNID][]types.StateEntry{}
if err != nil { for _, nid := range stateNIDs {
return nil, fmt.Errorf("v.db.StateEntries: %w", err) entries, err := v.db.StateEntries(ctx, nid)
if err != nil {
return nil, fmt.Errorf("v.db.StateEntries: %w", err)
}
stateEntriesMap[nid] = entries
} }
stateBlockNIDsMap := stateBlockNIDListMap(stateBlockNIDLists)
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine the entries from all the snapshots of state after each prev event into a single list. // Combine the entries from all the snapshots of state after each prev event into a single list.
var combined []types.StateEntry var combined []types.StateEntry
for _, prevState := range prevStates { for _, prevState := range prevStates {
// Grab the list of state data NIDs for this snapshot.
stateBlockNIDs, ok := stateBlockNIDsMap.lookup(prevState.BeforeStateSnapshotNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state snapshot numeric ID %d", prevState.BeforeStateSnapshotNID))
}
// Combine all the state entries for this snapshot. // Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in. // The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDs { entries, ok := stateEntriesMap[prevState.BeforeStateSnapshotNID]
entries, ok := stateEntriesMap.lookup(stateBlockNID) if !ok {
if !ok { // This should only get hit if the database is corrupt.
// This should only get hit if the database is corrupt. // It should be impossible for an event to reference a NID that doesn't exist
// It should be impossible for an event to reference a NID that doesn't exist panic(fmt.Errorf("Corrupt DB: Missing state snapshot %d", prevState.BeforeStateSnapshotNID))
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
}
fullState = append(fullState, entries...)
} }
fullState = append(fullState, entries...)
if prevState.IsStateEvent() && !prevState.IsRejected { if prevState.IsStateEvent() && !prevState.IsRejected {
// If the prev event was a state event then add an entry for the event itself // If the prev event was a state event then add an entry for the event itself
// so that we get the state after the event rather than the state before. // so that we get the state after the event rather than the state before.
@ -192,13 +149,13 @@ func (v *StateResolution) DifferenceBetweeenStateSnapshots(
if oldStateNID != 0 { if oldStateNID != 0 {
oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID) oldEntries, err = v.LoadStateAtSnapshot(ctx, oldStateNID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
} }
} }
if newStateNID != 0 { if newStateNID != 0 {
newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID) newEntries, err = v.LoadStateAtSnapshot(ctx, newStateNID)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, fmt.Errorf("v.LoadStateAtSnapshot: %w", err)
} }
} }
@ -299,33 +256,9 @@ func (v *StateResolution) loadStateAtSnapshotForNumericTuples(
stateNID types.StateSnapshotNID, stateNID types.StateSnapshotNID,
stateKeyTuples []types.StateKeyTuple, stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntry, error) { ) ([]types.StateEntry, error) {
stateBlockNIDLists, err := v.db.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID}) fullState, err := v.db.StateEntries(ctx, stateNID)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("v.db.StateEntries: %w", err)
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := v.db.StateEntriesForTuples(
ctx, stateBlockNIDList.StateBlockNIDs, stateKeyTuples,
)
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
if !ok {
// If the block is missing from the map it means that none of its entries matched a requested tuple.
// This can happen if the block doesn't contain an update for one of the requested tuples.
// If none of the requested tuples are in the block then it can be safely skipped.
continue
}
fullState = append(fullState, entries...)
} }
// Stable sort so that the most recent entry for each state key stays // Stable sort so that the most recent entry for each state key stays
@ -549,7 +482,8 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
// 2) There weren't any prev_events for this event so the state is // 2) There weren't any prev_events for this event so the state is
// empty. // empty.
metrics.algorithm = "empty_state" metrics.algorithm = "empty_state"
stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID, nil, nil) stateNID, err := v.db.AddState(ctx, v.roomInfo.RoomNID /*nil,*/, nil)
logrus.Warnf("Empty prev state added state snapshot %d", stateNID)
if err != nil { if err != nil {
err = fmt.Errorf("v.db.AddState: %w", err) err = fmt.Errorf("v.db.AddState: %w", err)
} }
@ -566,28 +500,56 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
metrics.algorithm = "no_change" metrics.algorithm = "no_change"
return metrics.stop(prevState.BeforeStateSnapshotNID, nil) return metrics.stop(prevState.BeforeStateSnapshotNID, nil)
} }
// The previous event was a state event so we need to store a copy
// of the previous state updated with that event. oldState, err := v.db.StateEntries(
stateBlockNIDLists, err := v.db.StateBlockNIDs( ctx, prevState.BeforeStateSnapshotNID,
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
) )
if err != nil { if err != nil {
metrics.algorithm = "_load_state_blocks" return 0, fmt.Errorf("v.db.StateEntries: %w", err)
return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
} }
stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs found := false
if len(stateBlockNIDs) < maxStateBlockNIDs { for _, s := range oldState {
// 4) The number of state data blocks is small enough that we can just if s.EventNID == prevState.StateEntry.EventNID {
// add the state event as a block of size one to the end of the blocks. found = true
metrics.algorithm = "single_delta" break
stateNID, err := v.db.AddState( }
ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry}, }
if !found {
oldState = append(oldState, prevState.StateEntry)
}
stateNID, err := v.db.AddState(
ctx, v.roomInfo.RoomNID, oldState,
)
logrus.Warnf("Single prev state added state snapshot %d", stateNID)
if err != nil {
return 0, fmt.Errorf("v.db.AddState: %w", err)
}
return stateNID, nil
/*
// The previous event was a state event so we need to store a copy
// of the previous state updated with that event.
stateBlockNIDLists, err := v.db.StateBlockNIDs(
ctx, []types.StateSnapshotNID{prevState.BeforeStateSnapshotNID},
) )
if err != nil { if err != nil {
err = fmt.Errorf("v.db.AddState: %w", err) metrics.algorithm = "_load_state_blocks"
return metrics.stop(0, fmt.Errorf("v.db.StateBlockNIDs: %w", err))
} }
return metrics.stop(stateNID, err) stateBlockNIDs := stateBlockNIDLists[0].StateBlockNIDs
} if len(stateBlockNIDs) < maxStateBlockNIDs {
// 4) The number of state data blocks is small enough that we can just
// add the state event as a block of size one to the end of the blocks.
metrics.algorithm = "single_delta"
stateNID, err := v.db.AddState(
ctx, v.roomInfo.RoomNID, stateBlockNIDs, []types.StateEntry{prevState.StateEntry},
)
if err != nil {
err = fmt.Errorf("v.db.AddState: %w", err)
}
return metrics.stop(stateNID, err)
}
*/
// If there are too many deltas then we need to calculate the full state // If there are too many deltas then we need to calculate the full state
// So fall through to calculateAndStoreStateAfterManyEvents // So fall through to calculateAndStoreStateAfterManyEvents
} }
@ -596,6 +558,7 @@ func (v *StateResolution) CalculateAndStoreStateAfterEvents(
if err != nil { if err != nil {
return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err) return 0, fmt.Errorf("v.calculateAndStoreStateAfterManyEvents: %w", err)
} }
logrus.Warnf("Multiple prev states added state snapshot %d", stateNID)
return stateNID, nil return stateNID, nil
} }
@ -626,7 +589,7 @@ func (v *StateResolution) calculateAndStoreStateAfterManyEvents(
// previous state. // previous state.
metrics.conflictLength = conflictLength metrics.conflictLength = conflictLength
metrics.fullStateLength = len(state) metrics.fullStateLength = len(state)
return metrics.stop(v.db.AddState(ctx, roomNID, nil, state)) return metrics.stop(v.db.AddState(ctx, roomNID /*nil,*/, state))
} }
func (v *StateResolution) calculateStateAfterManyEvents( func (v *StateResolution) calculateStateAfterManyEvents(
@ -996,34 +959,6 @@ func (s stateEntrySorter) Len() int { return len(s) }
func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) } func (s stateEntrySorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s stateEntrySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type stateBlockNIDListMap []types.StateBlockNIDList
func (m stateBlockNIDListMap) lookup(stateNID types.StateSnapshotNID) (stateBlockNIDs []types.StateBlockNID, ok bool) {
list := []types.StateBlockNIDList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateSnapshotNID >= stateNID
})
if i < len(list) && list[i].StateSnapshotNID == stateNID {
ok = true
stateBlockNIDs = list[i].StateBlockNIDs
}
return
}
type stateEntryListMap []types.StateEntryList
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type stateEntryByStateKeySorter []types.StateEntry type stateEntryByStateKeySorter []types.StateEntry
func (s stateEntryByStateKeySorter) Len() int { return len(s) } func (s stateEntryByStateKeySorter) Len() int { return len(s) }

View file

@ -33,7 +33,6 @@ type Database interface {
AddState( AddState(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (types.StateSnapshotNID, error) ) (types.StateSnapshotNID, error)
// Look up the state of a room at each event for a list of string event IDs. // Look up the state of a room at each event for a list of string event IDs.
@ -47,21 +46,9 @@ type Database interface {
// Look up the numeric IDs for a list of string event state keys. // Look up the numeric IDs for a list of string event state keys.
// Returns a map from string state key to numeric ID for the state key. // Returns a map from string state key to numeric ID for the state key.
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error) EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
// Look up the numeric state data IDs for each numeric state snapshot ID
// The returned slice is sorted by numeric state snapshot ID.
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
// Look up the state data for each numeric state data ID // Look up the state data for each numeric state data ID
// The returned slice is sorted by numeric state data ID. // The returned slice is sorted by numeric state data ID.
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) StateEntries(ctx context.Context, stateSnapshotNID types.StateSnapshotNID) ([]types.StateEntry, error)
// Look up the state data for the state key tuples for each numeric state block ID
// This is used to fetch a subset of the room state at a snapshot.
// If a block doesn't contain any of the requested tuples then it can be discarded from the result.
// The returned slice is sorted by numeric state block ID.
StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs. // Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events. // Returns a sorted list of events.
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error) Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)

View file

@ -88,6 +88,12 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id = ANY($1)" + " WHERE event_id = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
// Bulk lookup of events by numeric ID.
const bulkSelectStateEventByNIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_nid = ANY($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" + const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id = ANY($1)" " WHERE event_id = ANY($1)"
@ -127,6 +133,7 @@ type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateEventByNIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt updateEventStateStmt *sql.Stmt
selectEventSentToOutputStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt
@ -151,6 +158,7 @@ func NewPostgresEventsTable(db *sql.DB) (tables.Events, error) {
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventStmt, selectEventSQL}, {&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventStateStmt, updateEventStateSQL},
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
@ -238,6 +246,48 @@ func (s *eventStatements) BulkSelectStateEventByID(
return results, nil return results, nil
} }
// bulkSelectStateEventByNID lookups a list of state events by internal numeric ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.StateEntry, error) {
rows, err := s.bulkSelectStateEventByNIDStmt.QueryContext(ctx, eventNIDsAsArray(eventNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
// We know that we will only get as many results as event NIDs
// because of the unique constraint on event NIDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than NIDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
); err != nil {
return nil, err
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(eventNIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
// If this turns out to be impossible and we do need the debug information here, it would be better
// to do it as a separate query rather than slowing down/complicating the internal case.
return nil, types.MissingEventError(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)),
)
}
return results, nil
}
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.

View file

@ -1,292 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"sort"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
const stateDataSchema = `
-- The state data map.
-- Designed to give enough information to run the state resolution algorithm
-- without hitting the database in the internal case.
-- TODO: Is it worth replacing the unique btree index with a covering index so
-- that postgres could lookup the state using an index-only scan?
-- The type and state_key are included in the index to make it easier to
-- lookup a specific (type, state_key) pair for an event. It also makes it easy
-- to read the state for a given state_block_nid ordered by (type, state_key)
-- which in turn makes it easier to merge state data blocks.
CREATE SEQUENCE IF NOT EXISTS roomserver_state_block_nid_seq;
CREATE TABLE IF NOT EXISTS roomserver_state_block (
-- Local numeric ID for this state data.
state_block_nid bigint NOT NULL,
event_type_nid bigint NOT NULL,
event_state_key_nid bigint NOT NULL,
event_nid bigint NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
);
`
const insertStateDataSQL = "" +
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
" VALUES ($1, $2, $3, $4)"
const selectNextStateBlockNIDSQL = "" +
"SELECT nextval('roomserver_state_block_nid_seq')"
// Bulk state lookup by numeric state block ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
// This means that all the entries for a given state_block_nid will appear
// together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two
// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
// Bulk state lookup by numeric state block ID.
// Filters the rows in each block to the requested types and state keys.
// We would like to restrict to particular type state key pairs but we are
// restricted by the query language to pull the cross product of a list
// of types and a list state_keys. So we have to filter the result in the
// application to restrict it to the list of event types and state keys we
// actually wanted.
const bulkSelectFilteredStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid = ANY($1)" +
" AND event_type_nid = ANY($2) AND event_state_key_nid = ANY($3)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
type stateBlockStatements struct {
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
func NewPostgresStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{}
_, err := db.Exec(stateDataSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
}
func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context,
txn *sql.Tx,
entries []types.StateEntry,
) (types.StateBlockNID, error) {
stateBlockNID, err := s.selectNextStateBlockNID(ctx)
if err != nil {
return 0, err
}
for _, entry := range entries {
_, err := s.insertStateDataStmt.ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil {
return 0, err
}
}
return stateBlockNID, nil
}
func (s *stateBlockStatements) selectNextStateBlockNID(
ctx context.Context,
) (types.StateBlockNID, error) {
var stateBlockNID int64
err := s.selectNextStateBlockNIDStmt.QueryRowContext(ctx).Scan(&stateBlockNID)
return types.StateBlockNID(stateBlockNID), err
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
rows, err := s.bulkSelectStateBlockEntriesStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList
i := 0
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err = rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list.
current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
i++
}
current.StateEntries = append(current.StateEntries, entry)
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateBlockNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(stateBlockNIDs))
}
return results, err
}
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
rows, err := s.bulkSelectFilteredStateBlockEntriesStmt.QueryContext(
ctx,
stateBlockNIDsAsArray(stateBlockNIDs),
eventTypeNIDArray,
eventStateKeyNIDArray,
)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
var results []types.StateEntryList
var current types.StateEntryList
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
// We can use binary search here because we sorted the tuples earlier
if !tuples.contains(entry.StateKeyTuple) {
// The select will return the cross product of types and state keys.
// So we need to check if type of the entry is in the list.
continue
}
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we append the current entry to the results and start adding to a new one.
// The first time through the loop current will be empty.
if current.StateEntries != nil {
results = append(results, current)
}
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
}
current.StateEntries = append(current.StateEntries, entry)
}
// Add the last entry to the list if it is not empty.
if current.StateEntries != nil {
results = append(results, current)
}
return results, rows.Err()
}
func stateBlockNIDsAsArray(stateBlockNIDs []types.StateBlockNID) pq.Int64Array {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
return pq.Int64Array(nids)
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs pq.Int64Array, eventStateKeyNIDs pq.Int64Array) {
eventTypeNIDs = make(pq.Int64Array, len(s))
eventStateKeyNIDs = make(pq.Int64Array, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -1,126 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSnapshotSchema = `
-- The state of a room before an event.
-- Stored as a list of state_block entries stored in a separate table.
-- The actual state is constructed by combining all the state_block entries
-- referenced by state_block_nids together. If the same state key tuple appears
-- multiple times then the entry from the later state_block clobbers the earlier
-- entries.
-- This encoding format allows us to implement a delta encoding which is useful
-- because room state tends to accumulate small changes over time. Although if
-- the list of deltas becomes too long it becomes more efficient to encode
-- the full state under single state_block_nid.
CREATE SEQUENCE IF NOT EXISTS roomserver_state_snapshot_nid_seq;
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
-- Local numeric ID for the state.
state_snapshot_nid bigint PRIMARY KEY DEFAULT nextval('roomserver_state_snapshot_nid_seq'),
-- Local numeric ID of the room this state is for.
-- Unused in normal operation, but useful for background work or ad-hoc debugging.
room_nid bigint NOT NULL,
-- List of state_block_nids, stored sorted by state_block_nid.
state_block_nids bigint[] NOT NULL
);
`
const insertStateSQL = "" +
"INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)" +
" VALUES ($1, $2)" +
" RETURNING state_snapshot_nid"
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid = ANY($1) ORDER BY state_snapshot_nid ASC"
type stateSnapshotStatements struct {
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func NewPostgresStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.Prepare(db)
}
func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
nids := make([]int64, len(stateBlockNIDs))
for i := range stateBlockNIDs {
nids[i] = int64(stateBlockNIDs[i])
}
err = sqlutil.TxStmt(txn, s.insertStateStmt).QueryRowContext(ctx, int64(roomNID), pq.Int64Array(nids)).Scan(&stateNID)
return
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]int64, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := s.bulkSelectStateBlockNIDsStmt.QueryContext(ctx, pq.Int64Array(nids))
if err != nil {
return nil, err
}
defer rows.Close() // nolint: errcheck
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDs pq.Int64Array
if err = rows.Scan(&result.StateSnapshotNID, &stateBlockNIDs); err != nil {
return nil, err
}
result.StateBlockNIDs = make([]types.StateBlockNID, len(stateBlockNIDs))
for k := range stateBlockNIDs {
result.StateBlockNIDs[k] = types.StateBlockNID(stateBlockNIDs[k])
}
}
if err = rows.Err(); err != nil {
return nil, err
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}
return results, nil
}

View file

@ -0,0 +1,124 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"fmt"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state (
state_nid BIGSERIAL PRIMARY KEY,
room_nid bigint NOT NULL,
event_nids bigint[] NOT NULL,
UNIQUE (room_nid, event_nids),
CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid)
);
`
const insertNewStateSnapshotSQL = "" +
"INSERT INTO roomserver_state (room_nid, event_nids)" +
" VALUES ($1, $2)" +
" ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" +
" RETURNING state_nid"
const bulkSelectNewStateSnapshotSQL = "" +
"SELECT state_nid, event_nids" +
" FROM roomserver_state WHERE state_nid = ANY($1)" +
" ORDER BY state_nid"
type stateStatements struct {
insertStateStmt *sql.Stmt
bulkSelectStateStmt *sql.Stmt
}
func NewPostgresStateTable(db *sql.DB) (tables.State, error) {
s := &stateStatements{}
_, err := db.Exec(stateSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateStmt, insertNewStateSnapshotSQL},
{&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL},
}.Prepare(db)
}
func (s *stateStatements) InsertState(
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventNIDs []types.EventNID,
) (types.StateSnapshotNID, error) {
stmt := sqlutil.TxStmt(txn, s.insertStateStmt)
var id int64
var err error
eventNIDs = types.DeduplicateEventNIDs(eventNIDs)
if err = stmt.QueryRowContext(ctx, int64(roomNID), eventNIDsAsArray(eventNIDs)).Scan(&id); err != nil {
return 0, fmt.Errorf("stmt.ExecContext: %w", err)
}
return types.StateSnapshotNID(id), err
}
func (s *stateStatements) BulkSelectState(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) (map[types.StateSnapshotNID][]types.EventNID, error) {
rows, err := s.bulkSelectStateStmt.QueryContext(ctx, stateSnapshotNIDsAsArray(stateNIDs))
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := map[types.StateSnapshotNID][]types.EventNID{}
for rows.Next() {
var stateNID int64
var eventNIDs pq.Int64Array
if err = rows.Scan(&stateNID, &eventNIDs); err != nil {
return nil, err
}
for _, id := range eventNIDs {
results[types.StateSnapshotNID(stateNID)] = append(
results[types.StateSnapshotNID(stateNID)],
types.EventNID(id),
)
}
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("rows.Err: %w", err)
}
if len(results) != len(stateNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs))
}
return results, err
}
func stateSnapshotNIDsAsArray(stateSnapshotNIDs []types.StateSnapshotNID) pq.Int64Array {
nids := make([]int64, len(stateSnapshotNIDs))
for i := range stateSnapshotNIDs {
nids[i] = int64(stateSnapshotNIDs[i])
}
return nids
}

View file

@ -86,11 +86,17 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro
if err != nil { if err != nil {
return err return err
} }
stateBlock, err := NewPostgresStateBlockTable(db) /*
if err != nil { stateBlock, err := NewPostgresStateBlockTable(db)
return err if err != nil {
} return err
stateSnapshot, err := NewPostgresStateSnapshotTable(db) }
stateSnapshot, err := NewPostgresStateSnapshotTable(db)
if err != nil {
return err
}
*/
state, err := NewPostgresStateTable(db)
if err != nil { if err != nil {
return err return err
} }
@ -128,14 +134,15 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) (err erro
EventsTable: events, EventsTable: events,
RoomsTable: rooms, RoomsTable: rooms,
TransactionsTable: transactions, TransactionsTable: transactions,
StateBlockTable: stateBlock, StateTable: state,
StateSnapshotTable: stateSnapshot, // StateBlockTable: stateBlock,
PrevEventsTable: prevEvents, // StateSnapshotTable: stateSnapshot,
RoomAliasesTable: roomAliases, PrevEventsTable: prevEvents,
InvitesTable: invites, RoomAliasesTable: roomAliases,
MembershipTable: membership, InvitesTable: invites,
PublishedTable: published, MembershipTable: membership,
RedactionsTable: redactions, PublishedTable: published,
RedactionsTable: redactions,
} }
return nil return nil
} }

View file

@ -5,7 +5,6 @@ import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"fmt" "fmt"
"sort"
"github.com/matrix-org/dendrite/internal/caching" "github.com/matrix-org/dendrite/internal/caching"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
@ -13,7 +12,6 @@ import (
"github.com/matrix-org/dendrite/roomserver/storage/tables" "github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types" "github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@ -36,8 +34,7 @@ type Database struct {
EventStateKeysTable tables.EventStateKeys EventStateKeysTable tables.EventStateKeys
RoomsTable tables.Rooms RoomsTable tables.Rooms
TransactionsTable tables.Transactions TransactionsTable tables.Transactions
StateSnapshotTable tables.StateSnapshot StateTable tables.State
StateBlockTable tables.StateBlock
RoomAliasesTable tables.RoomAliases RoomAliasesTable tables.RoomAliases
PrevEventsTable tables.PreviousEvents PrevEventsTable tables.PreviousEvents
InvitesTable tables.Invites InvitesTable tables.Invites
@ -113,16 +110,6 @@ func (d *Database) StateEntriesForEventIDs(
return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs) return d.EventsTable.BulkSelectStateEventByID(ctx, eventIDs)
} }
func (d *Database) StateEntriesForTuples(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
return d.StateBlockTable.BulkSelectFilteredStateBlockEntries(
ctx, stateBlockNIDs, stateKeyTuples,
)
}
func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) { func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error) {
if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok { if roomInfo, ok := d.Cache.GetRoomInfo(roomID); ok {
return &roomInfo, nil return &roomInfo, nil
@ -138,21 +125,16 @@ func (d *Database) RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo
func (d *Database) AddState( func (d *Database) AddState(
ctx context.Context, ctx context.Context,
roomNID types.RoomNID, roomNID types.RoomNID,
stateBlockNIDs []types.StateBlockNID,
state []types.StateEntry, state []types.StateEntry,
) (stateNID types.StateSnapshotNID, err error) { ) (stateNID types.StateSnapshotNID, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error { err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if len(state) > 0 { eventNIDs := make([]types.EventNID, 0, len(state))
var stateBlockNID types.StateBlockNID for _, s := range state {
stateBlockNID, err = d.StateBlockTable.BulkInsertStateData(ctx, txn, state) eventNIDs = append(eventNIDs, s.EventNID)
if err != nil {
return fmt.Errorf("d.StateBlockTable.BulkInsertStateData: %w", err)
}
stateBlockNIDs = append(stateBlockNIDs[:len(stateBlockNIDs):len(stateBlockNIDs)], stateBlockNID)
} }
stateNID, err = d.StateSnapshotTable.InsertState(ctx, txn, roomNID, stateBlockNIDs) stateNID, err = d.StateTable.InsertState(ctx, txn, roomNID, eventNIDs)
if err != nil { if err != nil {
return fmt.Errorf("d.StateSnapshotTable.InsertState: %w", err) return fmt.Errorf("d.StateTable.InsertState: %w", err)
} }
return nil return nil
}) })
@ -228,16 +210,22 @@ func (d *Database) LatestEventIDs(
return return
} }
func (d *Database) StateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
return d.StateSnapshotTable.BulkSelectStateBlockNIDs(ctx, stateNIDs)
}
func (d *Database) StateEntries( func (d *Database) StateEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID, ctx context.Context, stateSnapshotNID types.StateSnapshotNID,
) ([]types.StateEntryList, error) { ) ([]types.StateEntry, error) {
return d.StateBlockTable.BulkSelectStateBlockEntries(ctx, stateBlockNIDs) nids, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{stateSnapshotNID})
if err != nil {
return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err)
}
state, ok := nids[stateSnapshotNID]
if !ok {
return nil, fmt.Errorf("state snapshot %d not found", stateSnapshotNID)
}
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, state)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
}
return entries, nil
} }
func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error { func (d *Database) SetRoomAlias(ctx context.Context, alias string, roomID string, creatorUserID string) error {
@ -817,9 +805,17 @@ func (d *Database) GetStateEvent(ctx context.Context, roomID, evType, stateKey s
if err != nil { if err != nil {
return nil, err return nil, err
} }
entries, err := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) snapshots, err := d.StateTable.BulkSelectState(ctx, []types.StateSnapshotNID{roomInfo.StateSnapshotNID})
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("d.StateTable.BulkSelectState: %w", err)
}
nids, ok := snapshots[roomInfo.StateSnapshotNID]
if !ok {
return nil, fmt.Errorf("state snapshot %d not found", roomInfo.StateSnapshotNID)
}
entries, err := d.EventsTable.BulkSelectStateEventByNID(ctx, nids)
if err != nil {
return nil, fmt.Errorf("d.EventsTable.BulkSelectStateEventByNID: %w", err)
} }
var eventNIDs []types.EventNID var eventNIDs []types.EventNID
for _, e := range entries { for _, e := range entries {
@ -938,7 +934,7 @@ func (d *Database) GetBulkStateContent(ctx context.Context, roomIDs []string, tu
if roomInfo == nil || roomInfo.IsStub { if roomInfo == nil || roomInfo.IsStub {
continue continue
} }
entries, err2 := d.loadStateAtSnapshot(ctx, roomInfo.StateSnapshotNID) entries, err2 := d.StateEntries(ctx, roomInfo.StateSnapshotNID)
if err2 != nil { if err2 != nil {
return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2) return nil, fmt.Errorf("GetBulkStateContent: failed to load state for room %s : %w", roomID, err2)
} }
@ -1039,65 +1035,3 @@ func (d *Database) ForgetRoom(ctx context.Context, userID, roomID string, forget
return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget) return d.MembershipTable.UpdateForgetMembership(ctx, nil, roomNIDs[0], stateKeyNID, forget)
}) })
} }
// FIXME TODO: Remove all this - horrible dupe with roomserver/state. Can't use the original impl because of circular loops
// it should live in this package!
func (d *Database) loadStateAtSnapshot(
ctx context.Context, stateNID types.StateSnapshotNID,
) ([]types.StateEntry, error) {
stateBlockNIDLists, err := d.StateBlockNIDs(ctx, []types.StateSnapshotNID{stateNID})
if err != nil {
return nil, err
}
// We've asked for exactly one snapshot from the db so we should have exactly one entry in the result.
stateBlockNIDList := stateBlockNIDLists[0]
stateEntryLists, err := d.StateEntries(ctx, stateBlockNIDList.StateBlockNIDs)
if err != nil {
return nil, err
}
stateEntriesMap := stateEntryListMap(stateEntryLists)
// Combine all the state entries for this snapshot.
// The order of state block NIDs in the list tells us the order to combine them in.
var fullState []types.StateEntry
for _, stateBlockNID := range stateBlockNIDList.StateBlockNIDs {
entries, ok := stateEntriesMap.lookup(stateBlockNID)
if !ok {
// This should only get hit if the database is corrupt.
// It should be impossible for an event to reference a NID that doesn't exist
panic(fmt.Errorf("Corrupt DB: Missing state block numeric ID %d", stateBlockNID))
}
fullState = append(fullState, entries...)
}
// Stable sort so that the most recent entry for each state key stays
// remains later in the list than the older entries for the same state key.
sort.Stable(stateEntryByStateKeySorter(fullState))
// Unique returns the last entry and hence the most recent entry for each state key.
fullState = fullState[:util.Unique(stateEntryByStateKeySorter(fullState))]
return fullState, nil
}
type stateEntryListMap []types.StateEntryList
func (m stateEntryListMap) lookup(stateBlockNID types.StateBlockNID) (stateEntries []types.StateEntry, ok bool) {
list := []types.StateEntryList(m)
i := sort.Search(len(list), func(i int) bool {
return list[i].StateBlockNID >= stateBlockNID
})
if i < len(list) && list[i].StateBlockNID == stateBlockNID {
ok = true
stateEntries = list[i].StateEntries
}
return
}
type stateEntryByStateKeySorter []types.StateEntry
func (s stateEntryByStateKeySorter) Len() int { return len(s) }
func (s stateEntryByStateKeySorter) Less(i, j int) bool {
return s[i].StateKeyTuple.LessThan(s[j].StateKeyTuple)
}
func (s stateEntryByStateKeySorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -63,6 +63,11 @@ const bulkSelectStateEventByIDSQL = "" +
" WHERE event_id IN ($1)" + " WHERE event_id IN ($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC" " ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateEventByNIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
" WHERE event_nid IN ($1)" +
" ORDER BY event_type_nid, event_state_key_nid ASC"
const bulkSelectStateAtEventByIDSQL = "" + const bulkSelectStateAtEventByIDSQL = "" +
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" + "SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
" WHERE event_id IN ($1)" " WHERE event_id IN ($1)"
@ -103,6 +108,7 @@ type eventStatements struct {
insertEventStmt *sql.Stmt insertEventStmt *sql.Stmt
selectEventStmt *sql.Stmt selectEventStmt *sql.Stmt
bulkSelectStateEventByIDStmt *sql.Stmt bulkSelectStateEventByIDStmt *sql.Stmt
bulkSelectStateEventByNIDStmt *sql.Stmt
bulkSelectStateAtEventByIDStmt *sql.Stmt bulkSelectStateAtEventByIDStmt *sql.Stmt
updateEventStateStmt *sql.Stmt updateEventStateStmt *sql.Stmt
selectEventSentToOutputStmt *sql.Stmt selectEventSentToOutputStmt *sql.Stmt
@ -128,6 +134,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
{&s.insertEventStmt, insertEventSQL}, {&s.insertEventStmt, insertEventSQL},
{&s.selectEventStmt, selectEventSQL}, {&s.selectEventStmt, selectEventSQL},
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL}, {&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL}, {&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
{&s.updateEventStateStmt, updateEventStateSQL}, {&s.updateEventStateStmt, updateEventStateSQL},
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL}, {&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
@ -232,6 +239,57 @@ func (s *eventStatements) BulkSelectStateEventByID(
return results, err return results, err
} }
// bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError
func (s *eventStatements) BulkSelectStateEventByNID(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.StateEntry, error) {
///////////////
iEventNIDs := make([]interface{}, len(eventNIDs))
for k, v := range eventNIDs {
iEventNIDs[k] = v
}
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
///////////////
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
// We know that we will only get as many results as event IDs
// because of the unique constraint on event IDs.
// So we can allocate an array of the correct size now.
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
results := make([]types.StateEntry, len(eventNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
if err = rows.Scan(
&result.EventTypeNID,
&result.EventStateKeyNID,
&result.EventNID,
); err != nil {
return nil, err
}
}
if i != len(eventNIDs) {
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
// We don't know which ones were missing because we don't return the string IDs in the query.
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
// If this turns out to be impossible and we do need the debug information here, it would be better
// to do it as a separate query rather than slowing down/complicating the internal case.
return nil, types.MissingEventError(
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)),
)
}
return results, err
}
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID. // bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.

View file

@ -1,289 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"fmt"
"sort"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
"github.com/matrix-org/util"
)
const stateDataSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_block (
state_block_nid INTEGER NOT NULL,
event_type_nid INTEGER NOT NULL,
event_state_key_nid INTEGER NOT NULL,
event_nid INTEGER NOT NULL,
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
);
`
const insertStateDataSQL = "" +
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
" VALUES ($1, $2, $3, $4)"
const selectNextStateBlockNIDSQL = `
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
`
// Bulk state lookup by numeric state block ID.
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
// This means that all the entries for a given state_block_nid will appear
// together in the list and those entries will sorted by event_type_nid
// and event_state_key_nid. This property makes it easier to merge two
// state data blocks together.
const bulkSelectStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
// Bulk state lookup by numeric state block ID.
// Filters the rows in each block to the requested types and state keys.
// We would like to restrict to particular type state key pairs but we are
// restricted by the query language to pull the cross product of a list
// of types and a list state_keys. So we have to filter the result in the
// application to restrict it to the list of event types and state keys we
// actually wanted.
const bulkSelectFilteredStateBlockEntriesSQL = "" +
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
type stateBlockStatements struct {
db *sql.DB
insertStateDataStmt *sql.Stmt
selectNextStateBlockNIDStmt *sql.Stmt
bulkSelectStateBlockEntriesStmt *sql.Stmt
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
}
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
s := &stateBlockStatements{
db: db,
}
_, err := db.Exec(stateDataSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateDataStmt, insertStateDataSQL},
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
}.Prepare(db)
}
func (s *stateBlockStatements) BulkInsertStateData(
ctx context.Context, txn *sql.Tx,
entries []types.StateEntry,
) (types.StateBlockNID, error) {
if len(entries) == 0 {
return 0, nil
}
var stateBlockNID types.StateBlockNID
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
if err != nil {
return 0, err
}
for _, entry := range entries {
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
ctx,
int64(stateBlockNID),
int64(entry.EventTypeNID),
int64(entry.EventStateKeyNID),
int64(entry.EventNID),
)
if err != nil {
return 0, err
}
}
return stateBlockNID, err
}
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
) ([]types.StateEntryList, error) {
nids := make([]interface{}, len(stateBlockNIDs))
for k, v := range stateBlockNIDs {
nids[k] = v
}
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := make([]types.StateEntryList, len(stateBlockNIDs))
// current is a pointer to the StateEntryList to append the state entries to.
var current *types.StateEntryList
i := 0
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we start appending to the next entry in the list.
current = &results[i]
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
i++
}
current.StateEntries = append(current.StateEntries, entry)
}
if i != len(nids) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
}
return results, nil
}
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
ctx context.Context,
stateBlockNIDs []types.StateBlockNID,
stateKeyTuples []types.StateKeyTuple,
) ([]types.StateEntryList, error) {
tuples := stateKeyTupleSorter(stateKeyTuples)
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
sort.Sort(tuples)
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
var params []interface{}
for _, val := range stateBlockNIDs {
params = append(params, int64(val))
}
for _, val := range eventTypeNIDArray {
params = append(params, val)
}
for _, val := range eventStateKeyNIDArray {
params = append(params, val)
}
rows, err := s.db.QueryContext(
ctx,
sqlStatement,
params...,
)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
var results []types.StateEntryList
var current types.StateEntryList
for rows.Next() {
var (
stateBlockNID int64
eventTypeNID int64
eventStateKeyNID int64
eventNID int64
entry types.StateEntry
)
if err := rows.Scan(
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
); err != nil {
return nil, err
}
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
entry.EventNID = types.EventNID(eventNID)
// We can use binary search here because we sorted the tuples earlier
if !tuples.contains(entry.StateKeyTuple) {
// The select will return the cross product of types and state keys.
// So we need to check if type of the entry is in the list.
continue
}
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
// The state entry row is for a different state data block to the current one.
// So we append the current entry to the results and start adding to a new one.
// The first time through the loop current will be empty.
if current.StateEntries != nil {
results = append(results, current)
}
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
}
current.StateEntries = append(current.StateEntries, entry)
}
// Add the last entry to the list if it is not empty.
if current.StateEntries != nil {
results = append(results, current)
}
return results, nil
}
type stateKeyTupleSorter []types.StateKeyTuple
func (s stateKeyTupleSorter) Len() int { return len(s) }
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Check whether a tuple is in the list. Assumes that the list is sorted.
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
return i < len(s) && s[i] == value
}
// List the unique eventTypeNIDs and eventStateKeyNIDs.
// Assumes that the list is sorted.
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
eventTypeNIDs = make([]int64, len(s))
eventStateKeyNIDs = make([]int64, len(s))
for i := range s {
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
}
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
return
}
type int64Sorter []int64
func (s int64Sorter) Len() int { return len(s) }
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

View file

@ -1,86 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"sort"
"testing"
"github.com/matrix-org/dendrite/roomserver/types"
)
func TestStateKeyTupleSorter(t *testing.T) {
input := stateKeyTupleSorter{
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 1},
}
want := []types.StateKeyTuple{
{EventTypeNID: 1, EventStateKeyNID: 1},
{EventTypeNID: 1, EventStateKeyNID: 2},
{EventTypeNID: 1, EventStateKeyNID: 4},
{EventTypeNID: 2, EventStateKeyNID: 2},
}
doNotWant := []types.StateKeyTuple{
{EventTypeNID: 0, EventStateKeyNID: 0},
{EventTypeNID: 1, EventStateKeyNID: 3},
{EventTypeNID: 2, EventStateKeyNID: 1},
{EventTypeNID: 3, EventStateKeyNID: 1},
}
wantTypeNIDs := []int64{1, 2}
wantStateKeyNIDs := []int64{1, 2, 4}
// Sort the input and check it's in the right order.
sort.Sort(input)
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
for i := range want {
if input[i] != want[i] {
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
}
if !input.contains(want[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
}
}
for i := range doNotWant {
if input.contains(doNotWant[i]) {
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
}
}
if len(wantTypeNIDs) != len(gotTypeNIDs) {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
for i := range wantTypeNIDs {
if wantTypeNIDs[i] != gotTypeNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
}
for i := range wantStateKeyNIDs {
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
}
}
}

View file

@ -1,126 +0,0 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSnapshotSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL,
state_block_nids TEXT NOT NULL DEFAULT '[]'
);
`
const insertStateSQL = `
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
VALUES ($1, $2);`
// Bulk state data NID lookup.
// Sorting by state_snapshot_nid means we can use binary search over the result
// to lookup the state data NIDs for a state snapshot NID.
const bulkSelectStateBlockNIDsSQL = "" +
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
type stateSnapshotStatements struct {
db *sql.DB
insertStateStmt *sql.Stmt
bulkSelectStateBlockNIDsStmt *sql.Stmt
}
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
s := &stateSnapshotStatements{
db: db,
}
_, err := db.Exec(stateSnapshotSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateStmt, insertStateSQL},
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
}.Prepare(db)
}
func (s *stateSnapshotStatements) InsertState(
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
) (stateNID types.StateSnapshotNID, err error) {
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
if err != nil {
return
}
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
if err != nil {
return 0, err
}
lastRowID, err := res.LastInsertId()
if err != nil {
return 0, err
}
stateNID = types.StateSnapshotNID(lastRowID)
return
}
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) ([]types.StateBlockNIDList, error) {
nids := make([]interface{}, len(stateNIDs))
for k, v := range stateNIDs {
nids[k] = v
}
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
results := make([]types.StateBlockNIDList, len(stateNIDs))
i := 0
for ; rows.Next(); i++ {
result := &results[i]
var stateBlockNIDsJSON string
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
return nil, err
}
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
return nil, err
}
}
if i != len(stateNIDs) {
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
}
return results, nil
}

View file

@ -0,0 +1,132 @@
// Copyright 2017-2018 New Vector Ltd
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/roomserver/storage/shared"
"github.com/matrix-org/dendrite/roomserver/storage/tables"
"github.com/matrix-org/dendrite/roomserver/types"
)
const stateSchema = `
CREATE TABLE IF NOT EXISTS roomserver_state (
state_nid INTEGER PRIMARY KEY AUTOINCREMENT,
room_nid INTEGER NOT NULL,
event_nids TEXT NOT NULL,
UNIQUE (room_nid, event_nids),
CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid)
);
`
const insertNewStateSnapshotSQL = "" +
"INSERT INTO roomserver_state (room_nid, event_nids)" +
" VALUES ($1, $2)" +
" ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" +
" RETURNING state_nid"
const bulkSelectNewStateSnapshotSQL = "" +
"SELECT state_nid, event_nids" +
" FROM roomserver_state WHERE state_nid IN ($1)" +
" ORDER BY state_nid"
type stateStatements struct {
db *sql.DB
insertStateStmt *sql.Stmt
bulkSelectStateStmt *sql.Stmt
}
func NewPostgresStateTable(db *sql.DB) (tables.State, error) {
s := &stateStatements{db: db}
_, err := db.Exec(stateSchema)
if err != nil {
return nil, err
}
return s, shared.StatementList{
{&s.insertStateStmt, insertNewStateSnapshotSQL},
{&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL},
}.Prepare(db)
}
func (s *stateStatements) InsertState(
ctx context.Context,
txn *sql.Tx,
roomNID types.RoomNID,
eventNIDs []types.EventNID,
) (types.StateSnapshotNID, error) {
stmt := sqlutil.TxStmt(txn, s.insertStateStmt)
value, err := json.Marshal(types.DeduplicateEventNIDs(eventNIDs))
if err != nil {
return 0, fmt.Errorf("json.Marshal: %w", err)
}
var id int64
if err = stmt.QueryRowContext(ctx, int64(roomNID), value).Scan(&id); err != nil {
return 0, fmt.Errorf("stmt.ExecContext: %w", err)
}
return types.StateSnapshotNID(id), nil
}
func (s *stateStatements) BulkSelectState(
ctx context.Context, stateNIDs []types.StateSnapshotNID,
) (map[types.StateSnapshotNID][]types.EventNID, error) {
selectOrig := strings.Replace(bulkSelectNewStateSnapshotSQL, "($1)", sqlutil.QueryVariadic(len(stateNIDs)), 1)
selectStmt, err := s.db.Prepare(selectOrig)
if err != nil {
return nil, err
}
nids := make([]interface{}, len(stateNIDs))
for i := range stateNIDs {
nids[i] = int64(stateNIDs[i])
}
rows, err := selectStmt.QueryContext(ctx, nids...)
if err != nil {
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
results := map[types.StateSnapshotNID][]types.EventNID{}
for rows.Next() {
var stateNID int64
var eventNIDJSON json.RawMessage
if err = rows.Scan(&stateNID, &eventNIDJSON); err != nil {
return nil, fmt.Errorf("rows.Scan: %w", err)
}
var eventNIDs []int64
if err = json.Unmarshal(eventNIDJSON, &eventNIDs); err != nil {
return nil, fmt.Errorf("json.Unmarshal: %w", err)
}
for _, id := range eventNIDs {
results[types.StateSnapshotNID(stateNID)] = append(
results[types.StateSnapshotNID(stateNID)],
types.EventNID(id),
)
}
}
if err = rows.Err(); err != nil {
return nil, fmt.Errorf("rows.Err: %w", err)
}
if len(results) != len(stateNIDs) {
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs))
}
return results, err
}

View file

@ -98,11 +98,17 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
if err != nil { if err != nil {
return err return err
} }
stateBlock, err := NewSqliteStateBlockTable(db) /*
if err != nil { stateBlock, err := NewSqliteStateBlockTable(db)
return err if err != nil {
} return err
stateSnapshot, err := NewSqliteStateSnapshotTable(db) }
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
if err != nil {
return err
}
*/
state, err := NewPostgresStateTable(db)
if err != nil { if err != nil {
return err return err
} }
@ -131,17 +137,18 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
return err return err
} }
d.Database = shared.Database{ d.Database = shared.Database{
DB: db, DB: db,
Cache: cache, Cache: cache,
Writer: sqlutil.NewExclusiveWriter(), Writer: sqlutil.NewExclusiveWriter(),
EventsTable: events, EventsTable: events,
EventTypesTable: eventTypes, EventTypesTable: eventTypes,
EventStateKeysTable: eventStateKeys, EventStateKeysTable: eventStateKeys,
EventJSONTable: eventJSON, EventJSONTable: eventJSON,
RoomsTable: rooms, RoomsTable: rooms,
TransactionsTable: transactions, TransactionsTable: transactions,
StateBlockTable: stateBlock, StateTable: state,
StateSnapshotTable: stateSnapshot, // StateBlockTable: stateBlock,
// StateSnapshotTable: stateSnapshot,
PrevEventsTable: prevEvents, PrevEventsTable: prevEvents,
RoomAliasesTable: roomAliases, RoomAliasesTable: roomAliases,
InvitesTable: invites, InvitesTable: invites,

View file

@ -43,6 +43,7 @@ type Events interface {
// bulkSelectStateEventByID lookups a list of state events by event ID. // bulkSelectStateEventByID lookups a list of state events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError // If any of the requested events are missing from the database it returns a types.MissingEventError
BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error) BulkSelectStateEventByID(ctx context.Context, eventIDs []string) ([]types.StateEntry, error)
BulkSelectStateEventByNID(ctx context.Context, eventNIDs []types.EventNID) ([]types.StateEntry, error)
// BulkSelectStateAtEventByID lookups the state at a list of events by event ID. // BulkSelectStateAtEventByID lookups the state at a list of events by event ID.
// If any of the requested events are missing from the database it returns a types.MissingEventError. // If any of the requested events are missing from the database it returns a types.MissingEventError.
// If we do not have the state for any of the requested events it returns a types.MissingEventError. // If we do not have the state for any of the requested events it returns a types.MissingEventError.
@ -80,6 +81,12 @@ type Transactions interface {
SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error) SelectTransactionEventID(ctx context.Context, transactionID string, sessionID int64, userID string) (eventID string, err error)
} }
type State interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNIDs []types.EventNID) (types.StateSnapshotNID, error)
BulkSelectState(ctx context.Context, stateNIDs []types.StateSnapshotNID) (map[types.StateSnapshotNID][]types.EventNID, error)
}
/*
type StateSnapshot interface { type StateSnapshot interface {
InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error) InsertState(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID) (stateNID types.StateSnapshotNID, err error)
BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error) BulkSelectStateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
@ -90,6 +97,7 @@ type StateBlock interface {
BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error) BulkSelectStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error) BulkSelectFilteredStateBlockEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
} }
*/
type RoomAliases interface { type RoomAliases interface {
InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error) InsertRoomAlias(ctx context.Context, txn *sql.Tx, alias string, roomID string, creatorUserID string) (err error)

View file

@ -74,6 +74,24 @@ func (a StateEntry) LessThan(b StateEntry) bool {
return a.EventNID < b.EventNID return a.EventNID < b.EventNID
} }
// Deduplicate takes a set of event NIDs and ensures that there are no
// duplicates. If there are then we dedupe them.
func DeduplicateEventNIDs(a []EventNID) []EventNID {
if len(a) < 2 {
return a
}
sort.SliceStable(a, func(i, j int) bool {
return a[i] < a[j]
})
for i := 0; i < len(a)-1; i++ {
if a[i] == a[i+1] {
a = append(a[:i], a[i+1:]...)
i--
}
}
return a
}
// Deduplicate takes a set of state entries and ensures that there are no // Deduplicate takes a set of state entries and ensures that there are no
// duplicate (event type, state key) tuples. If there are then we dedupe // duplicate (event type, state key) tuples. If there are then we dedupe
// them, making sure that the latest/highest NIDs are always chosen. // them, making sure that the latest/highest NIDs are always chosen.
@ -151,18 +169,6 @@ const (
EmptyStateKeyNID = 1 EmptyStateKeyNID = 1
) )
// StateBlockNIDList is used to return the result of bulk StateBlockNID lookups from the database.
type StateBlockNIDList struct {
StateSnapshotNID StateSnapshotNID
StateBlockNIDs []StateBlockNID
}
// StateEntryList is used to return the result of bulk state entry lookups from the database.
type StateEntryList struct {
StateBlockNID StateBlockNID
StateEntries []StateEntry
}
// A MissingEventError is an error that happened because the roomserver was // A MissingEventError is an error that happened because the roomserver was
// missing requested events from its database. // missing requested events from its database.
type MissingEventError string type MissingEventError string