mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 14:12:47 +00:00
Initial work on simplified state storage
This commit is contained in:
parent
e08942fb00
commit
a799847070
21 changed files with 540 additions and 1297 deletions
|
@ -63,6 +63,11 @@ const bulkSelectStateEventByIDSQL = "" +
|
|||
" WHERE event_id IN ($1)" +
|
||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||
|
||||
const bulkSelectStateEventByNIDSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid FROM roomserver_events" +
|
||||
" WHERE event_nid IN ($1)" +
|
||||
" ORDER BY event_type_nid, event_state_key_nid ASC"
|
||||
|
||||
const bulkSelectStateAtEventByIDSQL = "" +
|
||||
"SELECT event_type_nid, event_state_key_nid, event_nid, state_snapshot_nid, is_rejected FROM roomserver_events" +
|
||||
" WHERE event_id IN ($1)"
|
||||
|
@ -103,6 +108,7 @@ type eventStatements struct {
|
|||
insertEventStmt *sql.Stmt
|
||||
selectEventStmt *sql.Stmt
|
||||
bulkSelectStateEventByIDStmt *sql.Stmt
|
||||
bulkSelectStateEventByNIDStmt *sql.Stmt
|
||||
bulkSelectStateAtEventByIDStmt *sql.Stmt
|
||||
updateEventStateStmt *sql.Stmt
|
||||
selectEventSentToOutputStmt *sql.Stmt
|
||||
|
@ -128,6 +134,7 @@ func NewSqliteEventsTable(db *sql.DB) (tables.Events, error) {
|
|||
{&s.insertEventStmt, insertEventSQL},
|
||||
{&s.selectEventStmt, selectEventSQL},
|
||||
{&s.bulkSelectStateEventByIDStmt, bulkSelectStateEventByIDSQL},
|
||||
{&s.bulkSelectStateEventByNIDStmt, bulkSelectStateEventByNIDSQL},
|
||||
{&s.bulkSelectStateAtEventByIDStmt, bulkSelectStateAtEventByIDSQL},
|
||||
{&s.updateEventStateStmt, updateEventStateSQL},
|
||||
{&s.updateEventSentToOutputStmt, updateEventSentToOutputSQL},
|
||||
|
@ -232,6 +239,57 @@ func (s *eventStatements) BulkSelectStateEventByID(
|
|||
return results, err
|
||||
}
|
||||
|
||||
// bulkSelectStateEventByID lookups a list of state events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError
|
||||
func (s *eventStatements) BulkSelectStateEventByNID(
|
||||
ctx context.Context, eventNIDs []types.EventNID,
|
||||
) ([]types.StateEntry, error) {
|
||||
///////////////
|
||||
iEventNIDs := make([]interface{}, len(eventNIDs))
|
||||
for k, v := range eventNIDs {
|
||||
iEventNIDs[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateEventByNIDSQL, "($1)", sqlutil.QueryVariadic(len(iEventNIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
///////////////
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, iEventNIDs...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateEventByID: rows.close() failed")
|
||||
// We know that we will only get as many results as event IDs
|
||||
// because of the unique constraint on event IDs.
|
||||
// So we can allocate an array of the correct size now.
|
||||
// We might get fewer results than IDs so we adjust the length of the slice before returning it.
|
||||
results := make([]types.StateEntry, len(eventNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
if err = rows.Scan(
|
||||
&result.EventTypeNID,
|
||||
&result.EventStateKeyNID,
|
||||
&result.EventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(eventNIDs) {
|
||||
// If there are fewer rows returned than IDs then we were asked to lookup event IDs we don't have.
|
||||
// We don't know which ones were missing because we don't return the string IDs in the query.
|
||||
// However it should be possible debug this by replaying queries or entries from the input kafka logs.
|
||||
// If this turns out to be impossible and we do need the debug information here, it would be better
|
||||
// to do it as a separate query rather than slowing down/complicating the internal case.
|
||||
return nil, types.MissingEventError(
|
||||
fmt.Sprintf("storage: state event IDs missing from the database (%d != %d)", i, len(eventNIDs)),
|
||||
)
|
||||
}
|
||||
return results, err
|
||||
}
|
||||
|
||||
// bulkSelectStateAtEventByID lookups the state at a list of events by event ID.
|
||||
// If any of the requested events are missing from the database it returns a types.MissingEventError.
|
||||
// If we do not have the state for any of the requested events it returns a types.MissingEventError.
|
||||
|
|
|
@ -1,289 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
"github.com/matrix-org/util"
|
||||
)
|
||||
|
||||
const stateDataSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_block (
|
||||
state_block_nid INTEGER NOT NULL,
|
||||
event_type_nid INTEGER NOT NULL,
|
||||
event_state_key_nid INTEGER NOT NULL,
|
||||
event_nid INTEGER NOT NULL,
|
||||
UNIQUE (state_block_nid, event_type_nid, event_state_key_nid)
|
||||
);
|
||||
`
|
||||
|
||||
const insertStateDataSQL = "" +
|
||||
"INSERT INTO roomserver_state_block (state_block_nid, event_type_nid, event_state_key_nid, event_nid)" +
|
||||
" VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectNextStateBlockNIDSQL = `
|
||||
SELECT IFNULL(MAX(state_block_nid), 0) + 1 FROM roomserver_state_block
|
||||
`
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Sort by the state_block_nid, event_type_nid, event_state_key_nid
|
||||
// This means that all the entries for a given state_block_nid will appear
|
||||
// together in the list and those entries will sorted by event_type_nid
|
||||
// and event_state_key_nid. This property makes it easier to merge two
|
||||
// state data blocks together.
|
||||
const bulkSelectStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
|
||||
// Bulk state lookup by numeric state block ID.
|
||||
// Filters the rows in each block to the requested types and state keys.
|
||||
// We would like to restrict to particular type state key pairs but we are
|
||||
// restricted by the query language to pull the cross product of a list
|
||||
// of types and a list state_keys. So we have to filter the result in the
|
||||
// application to restrict it to the list of event types and state keys we
|
||||
// actually wanted.
|
||||
const bulkSelectFilteredStateBlockEntriesSQL = "" +
|
||||
"SELECT state_block_nid, event_type_nid, event_state_key_nid, event_nid" +
|
||||
" FROM roomserver_state_block WHERE state_block_nid IN ($1)" +
|
||||
" AND event_type_nid IN ($2) AND event_state_key_nid IN ($3)" +
|
||||
" ORDER BY state_block_nid, event_type_nid, event_state_key_nid"
|
||||
|
||||
type stateBlockStatements struct {
|
||||
db *sql.DB
|
||||
insertStateDataStmt *sql.Stmt
|
||||
selectNextStateBlockNIDStmt *sql.Stmt
|
||||
bulkSelectStateBlockEntriesStmt *sql.Stmt
|
||||
bulkSelectFilteredStateBlockEntriesStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateBlockTable(db *sql.DB) (tables.StateBlock, error) {
|
||||
s := &stateBlockStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateDataSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateDataStmt, insertStateDataSQL},
|
||||
{&s.selectNextStateBlockNIDStmt, selectNextStateBlockNIDSQL},
|
||||
{&s.bulkSelectStateBlockEntriesStmt, bulkSelectStateBlockEntriesSQL},
|
||||
{&s.bulkSelectFilteredStateBlockEntriesStmt, bulkSelectFilteredStateBlockEntriesSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkInsertStateData(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
entries []types.StateEntry,
|
||||
) (types.StateBlockNID, error) {
|
||||
if len(entries) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
var stateBlockNID types.StateBlockNID
|
||||
err := sqlutil.TxStmt(txn, s.selectNextStateBlockNIDStmt).QueryRowContext(ctx).Scan(&stateBlockNID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, entry := range entries {
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStateDataStmt).ExecContext(
|
||||
ctx,
|
||||
int64(stateBlockNID),
|
||||
int64(entry.EventTypeNID),
|
||||
int64(entry.EventStateKeyNID),
|
||||
int64(entry.EventNID),
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return stateBlockNID, err
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectStateBlockEntries(
|
||||
ctx context.Context, stateBlockNIDs []types.StateBlockNID,
|
||||
) ([]types.StateEntryList, error) {
|
||||
nids := make([]interface{}, len(stateBlockNIDs))
|
||||
for k, v := range stateBlockNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||
|
||||
results := make([]types.StateEntryList, len(stateBlockNIDs))
|
||||
// current is a pointer to the StateEntryList to append the state entries to.
|
||||
var current *types.StateEntryList
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
if current == nil || types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we start appending to the next entry in the list.
|
||||
current = &results[i]
|
||||
current.StateBlockNID = types.StateBlockNID(stateBlockNID)
|
||||
i++
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
}
|
||||
if i != len(nids) {
|
||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", i, len(nids))
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
func (s *stateBlockStatements) BulkSelectFilteredStateBlockEntries(
|
||||
ctx context.Context,
|
||||
stateBlockNIDs []types.StateBlockNID,
|
||||
stateKeyTuples []types.StateKeyTuple,
|
||||
) ([]types.StateEntryList, error) {
|
||||
tuples := stateKeyTupleSorter(stateKeyTuples)
|
||||
// Sort the tuples so that we can run binary search against them as we filter the rows returned by the db.
|
||||
sort.Sort(tuples)
|
||||
|
||||
eventTypeNIDArray, eventStateKeyNIDArray := tuples.typesAndStateKeysAsArrays()
|
||||
sqlStatement := strings.Replace(bulkSelectFilteredStateBlockEntriesSQL, "($1)", sqlutil.QueryVariadic(len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($2)", sqlutil.QueryVariadicOffset(len(eventTypeNIDArray), len(stateBlockNIDs)), 1)
|
||||
sqlStatement = strings.Replace(sqlStatement, "($3)", sqlutil.QueryVariadicOffset(len(eventStateKeyNIDArray), len(stateBlockNIDs)+len(eventTypeNIDArray)), 1)
|
||||
|
||||
var params []interface{}
|
||||
for _, val := range stateBlockNIDs {
|
||||
params = append(params, int64(val))
|
||||
}
|
||||
for _, val := range eventTypeNIDArray {
|
||||
params = append(params, val)
|
||||
}
|
||||
for _, val := range eventStateKeyNIDArray {
|
||||
params = append(params, val)
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(
|
||||
ctx,
|
||||
sqlStatement,
|
||||
params...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectFilteredStateBlockEntries: rows.close() failed")
|
||||
|
||||
var results []types.StateEntryList
|
||||
var current types.StateEntryList
|
||||
for rows.Next() {
|
||||
var (
|
||||
stateBlockNID int64
|
||||
eventTypeNID int64
|
||||
eventStateKeyNID int64
|
||||
eventNID int64
|
||||
entry types.StateEntry
|
||||
)
|
||||
if err := rows.Scan(
|
||||
&stateBlockNID, &eventTypeNID, &eventStateKeyNID, &eventNID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry.EventTypeNID = types.EventTypeNID(eventTypeNID)
|
||||
entry.EventStateKeyNID = types.EventStateKeyNID(eventStateKeyNID)
|
||||
entry.EventNID = types.EventNID(eventNID)
|
||||
|
||||
// We can use binary search here because we sorted the tuples earlier
|
||||
if !tuples.contains(entry.StateKeyTuple) {
|
||||
// The select will return the cross product of types and state keys.
|
||||
// So we need to check if type of the entry is in the list.
|
||||
continue
|
||||
}
|
||||
|
||||
if types.StateBlockNID(stateBlockNID) != current.StateBlockNID {
|
||||
// The state entry row is for a different state data block to the current one.
|
||||
// So we append the current entry to the results and start adding to a new one.
|
||||
// The first time through the loop current will be empty.
|
||||
if current.StateEntries != nil {
|
||||
results = append(results, current)
|
||||
}
|
||||
current = types.StateEntryList{StateBlockNID: types.StateBlockNID(stateBlockNID)}
|
||||
}
|
||||
current.StateEntries = append(current.StateEntries, entry)
|
||||
}
|
||||
// Add the last entry to the list if it is not empty.
|
||||
if current.StateEntries != nil {
|
||||
results = append(results, current)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
type stateKeyTupleSorter []types.StateKeyTuple
|
||||
|
||||
func (s stateKeyTupleSorter) Len() int { return len(s) }
|
||||
func (s stateKeyTupleSorter) Less(i, j int) bool { return s[i].LessThan(s[j]) }
|
||||
func (s stateKeyTupleSorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
||||
|
||||
// Check whether a tuple is in the list. Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) contains(value types.StateKeyTuple) bool {
|
||||
i := sort.Search(len(s), func(i int) bool { return !s[i].LessThan(value) })
|
||||
return i < len(s) && s[i] == value
|
||||
}
|
||||
|
||||
// List the unique eventTypeNIDs and eventStateKeyNIDs.
|
||||
// Assumes that the list is sorted.
|
||||
func (s stateKeyTupleSorter) typesAndStateKeysAsArrays() (eventTypeNIDs []int64, eventStateKeyNIDs []int64) {
|
||||
eventTypeNIDs = make([]int64, len(s))
|
||||
eventStateKeyNIDs = make([]int64, len(s))
|
||||
for i := range s {
|
||||
eventTypeNIDs[i] = int64(s[i].EventTypeNID)
|
||||
eventStateKeyNIDs[i] = int64(s[i].EventStateKeyNID)
|
||||
}
|
||||
eventTypeNIDs = eventTypeNIDs[:util.SortAndUnique(int64Sorter(eventTypeNIDs))]
|
||||
eventStateKeyNIDs = eventStateKeyNIDs[:util.SortAndUnique(int64Sorter(eventStateKeyNIDs))]
|
||||
return
|
||||
}
|
||||
|
||||
type int64Sorter []int64
|
||||
|
||||
func (s int64Sorter) Len() int { return len(s) }
|
||||
func (s int64Sorter) Less(i, j int) bool { return s[i] < s[j] }
|
||||
func (s int64Sorter) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
|
@ -1,86 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
func TestStateKeyTupleSorter(t *testing.T) {
|
||||
input := stateKeyTupleSorter{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
}
|
||||
want := []types.StateKeyTuple{
|
||||
{EventTypeNID: 1, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 2},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 4},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 2},
|
||||
}
|
||||
doNotWant := []types.StateKeyTuple{
|
||||
{EventTypeNID: 0, EventStateKeyNID: 0},
|
||||
{EventTypeNID: 1, EventStateKeyNID: 3},
|
||||
{EventTypeNID: 2, EventStateKeyNID: 1},
|
||||
{EventTypeNID: 3, EventStateKeyNID: 1},
|
||||
}
|
||||
wantTypeNIDs := []int64{1, 2}
|
||||
wantStateKeyNIDs := []int64{1, 2, 4}
|
||||
|
||||
// Sort the input and check it's in the right order.
|
||||
sort.Sort(input)
|
||||
gotTypeNIDs, gotStateKeyNIDs := input.typesAndStateKeysAsArrays()
|
||||
|
||||
for i := range want {
|
||||
if input[i] != want[i] {
|
||||
t.Errorf("Wanted %#v at index %d got %#v", want[i], i, input[i])
|
||||
}
|
||||
|
||||
if !input.contains(want[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be true but got false", input, want[i])
|
||||
}
|
||||
}
|
||||
|
||||
for i := range doNotWant {
|
||||
if input.contains(doNotWant[i]) {
|
||||
t.Errorf("Wanted %#v.contains(%#v) to be false but got true", input, doNotWant[i])
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantTypeNIDs) != len(gotTypeNIDs) {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
|
||||
for i := range wantTypeNIDs {
|
||||
if wantTypeNIDs[i] != gotTypeNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
|
||||
if len(wantStateKeyNIDs) != len(gotStateKeyNIDs) {
|
||||
t.Fatalf("Wanted state key NIDs %#v got %#v", wantStateKeyNIDs, gotStateKeyNIDs)
|
||||
}
|
||||
|
||||
for i := range wantStateKeyNIDs {
|
||||
if wantStateKeyNIDs[i] != gotStateKeyNIDs[i] {
|
||||
t.Fatalf("Wanted type NIDs %#v got %#v", wantTypeNIDs, gotTypeNIDs)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,126 +0,0 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const stateSnapshotSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state_snapshots (
|
||||
state_snapshot_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_nid INTEGER NOT NULL,
|
||||
state_block_nids TEXT NOT NULL DEFAULT '[]'
|
||||
);
|
||||
`
|
||||
|
||||
const insertStateSQL = `
|
||||
INSERT INTO roomserver_state_snapshots (room_nid, state_block_nids)
|
||||
VALUES ($1, $2);`
|
||||
|
||||
// Bulk state data NID lookup.
|
||||
// Sorting by state_snapshot_nid means we can use binary search over the result
|
||||
// to lookup the state data NIDs for a state snapshot NID.
|
||||
const bulkSelectStateBlockNIDsSQL = "" +
|
||||
"SELECT state_snapshot_nid, state_block_nids FROM roomserver_state_snapshots" +
|
||||
" WHERE state_snapshot_nid IN ($1) ORDER BY state_snapshot_nid ASC"
|
||||
|
||||
type stateSnapshotStatements struct {
|
||||
db *sql.DB
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateBlockNIDsStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewSqliteStateSnapshotTable(db *sql.DB) (tables.StateSnapshot, error) {
|
||||
s := &stateSnapshotStatements{
|
||||
db: db,
|
||||
}
|
||||
_, err := db.Exec(stateSnapshotSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateStmt, insertStateSQL},
|
||||
{&s.bulkSelectStateBlockNIDsStmt, bulkSelectStateBlockNIDsSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) InsertState(
|
||||
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID,
|
||||
) (stateNID types.StateSnapshotNID, err error) {
|
||||
stateBlockNIDsJSON, err := json.Marshal(stateBlockNIDs)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
insertStmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
res, err := insertStmt.ExecContext(ctx, int64(roomNID), string(stateBlockNIDsJSON))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
lastRowID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
stateNID = types.StateSnapshotNID(lastRowID)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *stateSnapshotStatements) BulkSelectStateBlockNIDs(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) ([]types.StateBlockNIDList, error) {
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for k, v := range stateNIDs {
|
||||
nids[k] = v
|
||||
}
|
||||
selectOrig := strings.Replace(bulkSelectStateBlockNIDsSQL, "($1)", sqlutil.QueryVariadic(len(nids)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockNIDs: rows.close() failed")
|
||||
results := make([]types.StateBlockNIDList, len(stateNIDs))
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
result := &results[i]
|
||||
var stateBlockNIDsJSON string
|
||||
if err := rows.Scan(&result.StateSnapshotNID, &stateBlockNIDsJSON); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := json.Unmarshal([]byte(stateBlockNIDsJSON), &result.StateBlockNIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if i != len(stateNIDs) {
|
||||
return nil, fmt.Errorf("storage: state NIDs missing from the database (%d != %d)", i, len(stateNIDs))
|
||||
}
|
||||
return results, nil
|
||||
}
|
132
roomserver/storage/sqlite3/state_table.go
Normal file
132
roomserver/storage/sqlite3/state_table.go
Normal file
|
@ -0,0 +1,132 @@
|
|||
// Copyright 2017-2018 New Vector Ltd
|
||||
// Copyright 2019-2020 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||
"github.com/matrix-org/dendrite/roomserver/types"
|
||||
)
|
||||
|
||||
const stateSchema = `
|
||||
CREATE TABLE IF NOT EXISTS roomserver_state (
|
||||
state_nid INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
room_nid INTEGER NOT NULL,
|
||||
event_nids TEXT NOT NULL,
|
||||
UNIQUE (room_nid, event_nids),
|
||||
CONSTRAINT fk_room_id FOREIGN KEY(room_nid) REFERENCES roomserver_rooms(room_nid)
|
||||
);
|
||||
`
|
||||
|
||||
const insertNewStateSnapshotSQL = "" +
|
||||
"INSERT INTO roomserver_state (room_nid, event_nids)" +
|
||||
" VALUES ($1, $2)" +
|
||||
" ON CONFLICT (room_nid, event_nids) DO UPDATE SET room_nid = $1" +
|
||||
" RETURNING state_nid"
|
||||
|
||||
const bulkSelectNewStateSnapshotSQL = "" +
|
||||
"SELECT state_nid, event_nids" +
|
||||
" FROM roomserver_state WHERE state_nid IN ($1)" +
|
||||
" ORDER BY state_nid"
|
||||
|
||||
type stateStatements struct {
|
||||
db *sql.DB
|
||||
insertStateStmt *sql.Stmt
|
||||
bulkSelectStateStmt *sql.Stmt
|
||||
}
|
||||
|
||||
func NewPostgresStateTable(db *sql.DB) (tables.State, error) {
|
||||
s := &stateStatements{db: db}
|
||||
_, err := db.Exec(stateSchema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return s, shared.StatementList{
|
||||
{&s.insertStateStmt, insertNewStateSnapshotSQL},
|
||||
{&s.bulkSelectStateStmt, bulkSelectNewStateSnapshotSQL},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
func (s *stateStatements) InsertState(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
roomNID types.RoomNID,
|
||||
eventNIDs []types.EventNID,
|
||||
) (types.StateSnapshotNID, error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertStateStmt)
|
||||
value, err := json.Marshal(types.DeduplicateEventNIDs(eventNIDs))
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
var id int64
|
||||
if err = stmt.QueryRowContext(ctx, int64(roomNID), value).Scan(&id); err != nil {
|
||||
return 0, fmt.Errorf("stmt.ExecContext: %w", err)
|
||||
}
|
||||
return types.StateSnapshotNID(id), nil
|
||||
}
|
||||
|
||||
func (s *stateStatements) BulkSelectState(
|
||||
ctx context.Context, stateNIDs []types.StateSnapshotNID,
|
||||
) (map[types.StateSnapshotNID][]types.EventNID, error) {
|
||||
selectOrig := strings.Replace(bulkSelectNewStateSnapshotSQL, "($1)", sqlutil.QueryVariadic(len(stateNIDs)), 1)
|
||||
selectStmt, err := s.db.Prepare(selectOrig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nids := make([]interface{}, len(stateNIDs))
|
||||
for i := range stateNIDs {
|
||||
nids[i] = int64(stateNIDs[i])
|
||||
}
|
||||
rows, err := selectStmt.QueryContext(ctx, nids...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectStateBlockEntries: rows.close() failed")
|
||||
results := map[types.StateSnapshotNID][]types.EventNID{}
|
||||
for rows.Next() {
|
||||
var stateNID int64
|
||||
var eventNIDJSON json.RawMessage
|
||||
if err = rows.Scan(&stateNID, &eventNIDJSON); err != nil {
|
||||
return nil, fmt.Errorf("rows.Scan: %w", err)
|
||||
}
|
||||
var eventNIDs []int64
|
||||
if err = json.Unmarshal(eventNIDJSON, &eventNIDs); err != nil {
|
||||
return nil, fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
for _, id := range eventNIDs {
|
||||
results[types.StateSnapshotNID(stateNID)] = append(
|
||||
results[types.StateSnapshotNID(stateNID)],
|
||||
types.EventNID(id),
|
||||
)
|
||||
}
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("rows.Err: %w", err)
|
||||
}
|
||||
if len(results) != len(stateNIDs) {
|
||||
return nil, fmt.Errorf("storage: state data NIDs missing from the database (%d != %d)", len(results), len(stateNIDs))
|
||||
}
|
||||
return results, err
|
||||
}
|
|
@ -98,11 +98,17 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateBlock, err := NewSqliteStateBlockTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
|
||||
/*
|
||||
stateBlock, err := NewSqliteStateBlockTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stateSnapshot, err := NewSqliteStateSnapshotTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*/
|
||||
state, err := NewPostgresStateTable(db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -131,17 +137,18 @@ func (d *Database) prepare(db *sql.DB, cache caching.RoomServerCaches) error {
|
|||
return err
|
||||
}
|
||||
d.Database = shared.Database{
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
TransactionsTable: transactions,
|
||||
StateBlockTable: stateBlock,
|
||||
StateSnapshotTable: stateSnapshot,
|
||||
DB: db,
|
||||
Cache: cache,
|
||||
Writer: sqlutil.NewExclusiveWriter(),
|
||||
EventsTable: events,
|
||||
EventTypesTable: eventTypes,
|
||||
EventStateKeysTable: eventStateKeys,
|
||||
EventJSONTable: eventJSON,
|
||||
RoomsTable: rooms,
|
||||
TransactionsTable: transactions,
|
||||
StateTable: state,
|
||||
// StateBlockTable: stateBlock,
|
||||
// StateSnapshotTable: stateSnapshot,
|
||||
PrevEventsTable: prevEvents,
|
||||
RoomAliasesTable: roomAliases,
|
||||
InvitesTable: invites,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue