Refactor StoreEvent and create a new RoomDatabase interface (#2985)

This PR changes a few things:
- It pulls out the creation of several NIDs from the `StoreEvent`
function to make the functions more reusable
- Uses more caching when using those NIDs to avoid DB round trips
This commit is contained in:
Till 2023-02-24 09:40:20 +01:00 committed by GitHub
parent e6aa0955ff
commit ad07b169b8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 459 additions and 302 deletions

View file

@ -69,15 +69,12 @@ type Database interface {
) ([]types.StateEntryList, error)
// Look up the Events for a list of numeric event IDs.
// Returns a sorted list of events.
Events(ctx context.Context, eventNIDs []types.EventNID) ([]types.Event, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
// Look up snapshot NID for an event ID string
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event, authEventNIDs []types.EventNID,
isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
// Look up the state entries for a list of string event IDs
// Returns an error if the there is an error talking to the database
// Returns a types.MissingEventError if the event IDs aren't in the database.
@ -87,7 +84,7 @@ type Database interface {
EventStateKeys(ctx context.Context, eventStateKeyNIDs []types.EventStateKeyNID) (map[types.EventStateKeyNID]string, error)
// Look up the numeric IDs for a list of events.
// Returns an error if there was a problem talking to the database.
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventNID, error)
EventNIDs(ctx context.Context, eventIDs []string) (map[string]types.EventMetadata, error)
// Set the state at an event. FIXME TODO: "at"
SetState(ctx context.Context, eventNID types.EventNID, stateNID types.StateSnapshotNID) error
// Lookup the event IDs for a batch of event numeric IDs.
@ -138,7 +135,7 @@ type Database interface {
// EventsFromIDs looks up the Events for a list of event IDs. Does not error if event was
// not found.
// Returns an error if the retrieval went wrong.
EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
// Publish or unpublish a room from the room directory.
PublishRoom(ctx context.Context, roomID, appserviceID, networkID string, publish bool) error
// Returns a list of room IDs for rooms which are published.
@ -182,4 +179,36 @@ type Database interface {
GetMembershipForHistoryVisibility(
ctx context.Context, userNID types.EventStateKeyNID, info *types.RoomInfo, eventIDs ...string,
) (map[string]*gomatrixserverlib.HeaderedEvent, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
}
type RoomDatabase interface {
// RoomInfo returns room information for the given room ID, or nil if there is no room.
RoomInfo(ctx context.Context, roomID string) (*types.RoomInfo, error)
// IsEventRejected returns true if the event is known and rejected.
IsEventRejected(ctx context.Context, roomNID types.RoomNID, eventID string) (rejected bool, err error)
MissingAuthPrevEvents(ctx context.Context, e *gomatrixserverlib.Event) (missingAuth, missingPrev []string, err error)
// Stores a matrix room event in the database. Returns the room NID, the state snapshot and the redacted event ID if any, or an error.
StoreEvent(ctx context.Context, event *gomatrixserverlib.Event, roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID, authEventNIDs []types.EventNID, isRejected bool) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error)
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
GetRoomUpdater(ctx context.Context, roomInfo *types.RoomInfo) (*shared.RoomUpdater, error)
StateEntriesForEventIDs(ctx context.Context, eventIDs []string, excludeRejected bool) ([]types.StateEntry, error)
GetMembershipEventNIDsForRoom(ctx context.Context, roomNID types.RoomNID, joinOnly bool, localOnly bool) ([]types.EventNID, error)
StateBlockNIDs(ctx context.Context, stateNIDs []types.StateSnapshotNID) ([]types.StateBlockNIDList, error)
StateEntries(ctx context.Context, stateBlockNIDs []types.StateBlockNID) ([]types.StateEntryList, error)
SnapshotNIDFromEventID(ctx context.Context, eventID string) (types.StateSnapshotNID, error)
BulkSelectSnapshotsFromEventIDs(ctx context.Context, eventIDs []string) (map[types.StateSnapshotNID][]string, error)
StateEntriesForTuples(ctx context.Context, stateBlockNIDs []types.StateBlockNID, stateKeyTuples []types.StateKeyTuple) ([]types.StateEntryList, error)
StateAtEventIDs(ctx context.Context, eventIDs []string) ([]types.StateAtEvent, error)
AddState(ctx context.Context, roomNID types.RoomNID, stateBlockNIDs []types.StateBlockNID, state []types.StateEntry) (types.StateSnapshotNID, error)
Events(ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error)
EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error)
LatestEventIDs(ctx context.Context, roomNID types.RoomNID) ([]gomatrixserverlib.EventReference, types.StateSnapshotNID, int64, error)
EventTypeNIDs(ctx context.Context, eventTypes []string) (map[string]types.EventTypeNID, error)
EventStateKeyNIDs(ctx context.Context, eventStateKeys []string) (map[string]types.EventStateKeyNID, error)
GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (types.RoomNID, error)
GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error)
GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (types.EventStateKeyNID, error)
}

View file

@ -140,10 +140,10 @@ const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid = ANY($1)"
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1)"
"SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
"SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id = ANY($1) AND sent_to_output = FALSE"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid = ANY($1)"
@ -520,20 +520,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) {
var stmt *sql.Stmt
if onlyUnsent {
stmt = sqlutil.TxStmt(txn, s.bulkSelectUnsentEventNIDStmt)
@ -545,14 +545,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs))
results := make(map[string]types.EventMetadata, len(eventIDs))
var eventID string
var eventNID int64
var roomNID int64
for rows.Next() {
if err = rows.Scan(&eventID, &eventNID); err != nil {
if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil {
return nil, err
}
results[eventID] = types.EventNID(eventNID)
results[eventID] = types.EventMetadata{
EventNID: types.EventNID(eventNID),
RoomNID: types.RoomNID(roomNID),
}
}
return results, rows.Err()
}

View file

@ -116,10 +116,8 @@ func (u *RoomUpdater) StorePreviousEvents(eventNID types.EventNID, previousEvent
})
}
func (u *RoomUpdater) Events(
ctx context.Context, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return u.d.events(ctx, u.txn, eventNIDs)
func (u *RoomUpdater) Events(ctx context.Context, _ types.RoomNID, eventNIDs []types.EventNID) ([]types.Event, error) {
return u.d.events(ctx, u.txn, u.roomInfo.RoomNID, eventNIDs)
}
func (u *RoomUpdater) SnapshotNIDFromEventID(
@ -197,12 +195,8 @@ func (u *RoomUpdater) StateAtEventIDs(
return u.d.EventsTable.BulkSelectStateAtEventByID(ctx, u.txn, eventIDs)
}
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, false)
}
func (u *RoomUpdater) UnsentEventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, eventIDs, true)
func (u *RoomUpdater) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
return u.d.eventsFromIDs(ctx, u.txn, roomNID, eventIDs, NoFilter)
}
// IsReferenced implements types.RoomRecentEventsUpdater

View file

@ -9,6 +9,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/matrix-org/dendrite/internal/caching"
@ -111,16 +112,31 @@ func (d *Database) eventStateKeyNIDs(
) (map[string]types.EventStateKeyNID, error) {
result := make(map[string]types.EventStateKeyNID)
eventStateKeys = util.UniqueStrings(eventStateKeys)
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, eventStateKeys)
if err != nil {
return nil, err
// first ask the cache about these keys
fetchEventStateKeys := make([]string, 0, len(eventStateKeys))
for _, eventStateKey := range eventStateKeys {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
if ok {
result[eventStateKey] = eventStateKeyNID
continue
}
fetchEventStateKeys = append(fetchEventStateKeys, eventStateKey)
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
if len(fetchEventStateKeys) > 0 {
nids, err := d.EventStateKeysTable.BulkSelectEventStateKeyNID(ctx, txn, fetchEventStateKeys)
if err != nil {
return nil, err
}
for eventStateKey, nid := range nids {
result[eventStateKey] = nid
}
}
// We received some nids, but are still missing some, work out which and create them
if len(eventStateKeys) > len(result) {
var nid types.EventStateKeyNID
var err error
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
for _, eventStateKey := range eventStateKeys {
if _, ok := result[eventStateKey]; ok {
@ -262,7 +278,7 @@ func (d *Database) addState(
func (d *Database) EventNIDs(
ctx context.Context, eventIDs []string,
) (map[string]types.EventNID, error) {
) (map[string]types.EventMetadata, error) {
return d.eventNIDs(ctx, nil, eventIDs, NoFilter)
}
@ -275,7 +291,7 @@ const (
func (d *Database) eventNIDs(
ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter,
) (map[string]types.EventNID, error) {
) (map[string]types.EventMetadata, error) {
switch filter {
case FilterUnsentOnly:
return d.EventsTable.BulkSelectUnsentEventNID(ctx, txn, eventIDs)
@ -325,11 +341,11 @@ func (d *Database) EventIDs(
return d.EventsTable.BulkSelectEventID(ctx, nil, eventNIDs)
}
func (d *Database) EventsFromIDs(ctx context.Context, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, eventIDs, NoFilter)
func (d *Database) EventsFromIDs(ctx context.Context, roomNID types.RoomNID, eventIDs []string) ([]types.Event, error) {
return d.eventsFromIDs(ctx, nil, roomNID, eventIDs, NoFilter)
}
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventIDs []string, filter UnsentFilter) ([]types.Event, error) {
nidMap, err := d.eventNIDs(ctx, txn, eventIDs, filter)
if err != nil {
return nil, err
@ -337,10 +353,16 @@ func (d *Database) eventsFromIDs(ctx context.Context, txn *sql.Tx, eventIDs []st
var nids []types.EventNID
for _, nid := range nidMap {
nids = append(nids, nid)
nids = append(nids, nid.EventNID)
if roomNID != 0 && roomNID != nid.RoomNID {
logrus.Errorf("expected events from room %d, but also found %d", roomNID, nid.RoomNID)
}
if roomNID == 0 {
roomNID = nid.RoomNID
}
}
return d.events(ctx, txn, nids)
return d.events(ctx, txn, roomNID, nids)
}
func (d *Database) LatestEventIDs(
@ -480,14 +502,18 @@ func (d *Database) GetInvitesForUser(
}
func (d *Database) Events(
ctx context.Context, eventNIDs []types.EventNID,
ctx context.Context, roomNID types.RoomNID, eventNIDs []types.EventNID,
) ([]types.Event, error) {
return d.events(ctx, nil, eventNIDs)
return d.events(ctx, nil, roomNID, eventNIDs)
}
func (d *Database) events(
ctx context.Context, txn *sql.Tx, inputEventNIDs types.EventNIDs,
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, inputEventNIDs types.EventNIDs,
) ([]types.Event, error) {
if roomNID == 0 {
// No need to go further, as we won't find any events for this room.
return nil, nil
}
sort.Sort(inputEventNIDs)
events := make(map[types.EventNID]*gomatrixserverlib.Event, len(inputEventNIDs))
eventNIDs := make([]types.EventNID, 0, len(inputEventNIDs))
@ -519,40 +545,34 @@ func (d *Database) events(
if err != nil {
return nil, err
}
eventIDs, _ := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
eventIDs, err := d.EventsTable.BulkSelectEventID(ctx, txn, eventNIDs)
if err != nil {
eventIDs = map[types.EventNID]string{}
}
var roomNIDs map[types.EventNID]types.RoomNID
roomNIDs, err = d.EventsTable.SelectRoomNIDsForEventNIDs(ctx, txn, eventNIDs)
if err != nil {
return nil, err
}
uniqueRoomNIDs := make(map[types.RoomNID]struct{})
for _, n := range roomNIDs {
uniqueRoomNIDs[n] = struct{}{}
}
roomVersions := make(map[types.RoomNID]gomatrixserverlib.RoomVersion)
fetchNIDList := make([]types.RoomNID, 0, len(uniqueRoomNIDs))
for n := range uniqueRoomNIDs {
if roomID, ok := d.Cache.GetRoomServerRoomID(n); ok {
if roomVersion, ok := d.Cache.GetRoomVersion(roomID); ok {
roomVersions[n] = roomVersion
continue
}
var roomVersion gomatrixserverlib.RoomVersion
var fetchRoomVersion bool
var ok bool
var roomID string
if roomID, ok = d.Cache.GetRoomServerRoomID(roomNID); ok {
roomVersion, ok = d.Cache.GetRoomVersion(roomID)
if !ok {
fetchRoomVersion = true
}
fetchNIDList = append(fetchNIDList, n)
}
dbRoomVersions, err := d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, fetchNIDList)
if err != nil {
return nil, err
}
for n, v := range dbRoomVersions {
roomVersions[n] = v
if roomVersion == "" || fetchRoomVersion {
var dbRoomVersions map[types.RoomNID]gomatrixserverlib.RoomVersion
dbRoomVersions, err = d.RoomsTable.SelectRoomVersionsForRoomNIDs(ctx, txn, []types.RoomNID{roomNID})
if err != nil {
return nil, err
}
if roomVersion, ok = dbRoomVersions[roomNID]; !ok {
return nil, fmt.Errorf("unable to find roomversion for room %d", roomNID)
}
}
for _, eventJSON := range eventJSONs {
roomNID := roomNIDs[eventJSON.EventNID]
roomVersion := roomVersions[roomNID]
events[eventJSON.EventNID], err = gomatrixserverlib.NewEventFromTrustedJSONWithEventID(
eventIDs[eventJSON.EventNID], eventJSON.EventJSON, false, roomVersion,
)
@ -624,73 +644,71 @@ func (d *Database) IsEventRejected(ctx context.Context, roomNID types.RoomNID, e
return d.EventsTable.SelectEventRejected(ctx, nil, roomNID, eventID)
}
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
return d.storeEvent(ctx, nil, event, authEventNIDs, isRejected)
// GetOrCreateRoomNID gets or creates a new roomNID for the given event
func (d *Database) GetOrCreateRoomNID(ctx context.Context, event *gomatrixserverlib.Event) (roomNID types.RoomNID, err error) {
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return 0, fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion)
if err != nil {
return err
}
return nil
})
return roomNID, err
}
func (d *Database) storeEvent(
ctx context.Context, updater *RoomUpdater, event *gomatrixserverlib.Event,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.RoomNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
roomNID types.RoomNID
eventTypeNID types.EventTypeNID
eventStateKeyNID types.EventStateKeyNID
eventNID types.EventNID
stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error
)
var txn *sql.Tx
if updater != nil && updater.txn != nil {
txn = updater.txn
}
// First writer is with a database-provided transaction, so that NIDs are assigned
// globally outside of the updater context, to help avoid races.
func (d *Database) GetOrCreateEventTypeNID(ctx context.Context, eventType string) (eventTypeNID types.EventTypeNID, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// TODO: Here we should aim to have two different code paths for new rooms
// vs existing ones.
// Get the default room version. If the client doesn't supply a room_version
// then we will use our configured default to create the room.
// https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-createroom
// Note that the below logic depends on the m.room.create event being the
// first event that is persisted to the database when creating or joining a
// room.
var roomVersion gomatrixserverlib.RoomVersion
if roomVersion, err = extractRoomVersionFromCreateEvent(event); err != nil {
return fmt.Errorf("extractRoomVersionFromCreateEvent: %w", err)
}
if roomNID, err = d.assignRoomNID(ctx, txn, event.RoomID(), roomVersion); err != nil {
return fmt.Errorf("d.assignRoomNID: %w", err)
}
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, event.Type()); err != nil {
if eventTypeNID, err = d.assignEventTypeNID(ctx, txn, eventType); err != nil {
return fmt.Errorf("d.assignEventTypeNID: %w", err)
}
return nil
})
return eventTypeNID, err
}
eventStateKey := event.StateKey()
// Assigned a numeric ID for the state_key if there is one present.
// Otherwise set the numeric ID for the state_key to 0.
if eventStateKey != nil {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return fmt.Errorf("d.assignStateKeyNID: %w", err)
}
func (d *Database) GetOrCreateEventStateKeyNID(ctx context.Context, eventStateKey *string) (eventStateKeyNID types.EventStateKeyNID, err error) {
if eventStateKey == nil {
return 0, nil
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventStateKeyNID, err = d.assignStateKeyNID(ctx, txn, *eventStateKey); err != nil {
return fmt.Errorf("d.assignStateKeyNID: %w", err)
}
return nil
})
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
return 0, err
}
return eventStateKeyNID, nil
}
func (d *Database) StoreEvent(
ctx context.Context, event *gomatrixserverlib.Event,
roomNID types.RoomNID, eventTypeNID types.EventTypeNID, eventStateKeyNID types.EventStateKeyNID,
authEventNIDs []types.EventNID, isRejected bool,
) (types.EventNID, types.StateAtEvent, *gomatrixserverlib.Event, string, error) {
var (
eventNID types.EventNID
stateNID types.StateSnapshotNID
redactionEvent *gomatrixserverlib.Event
redactedEventID string
err error
)
// Second writer is using the database-provided transaction, probably from the
// room updater, for easy roll-back if required.
err = d.Writer.Do(d.DB, txn, func(txn *sql.Tx) error {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if eventNID, stateNID, err = d.EventsTable.InsertEvent(
ctx,
txn,
@ -718,7 +736,7 @@ func (d *Database) storeEvent(
return fmt.Errorf("d.EventJSONTable.InsertEventJSON: %w", err)
}
if !isRejected { // ignore rejected redaction events
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, eventNID, event)
redactionEvent, redactedEventID, err = d.handleRedactions(ctx, txn, roomNID, eventNID, event)
if err != nil {
return fmt.Errorf("d.handleRedactions: %w", err)
}
@ -726,7 +744,7 @@ func (d *Database) storeEvent(
return nil
})
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.Writer.Do: %w", err)
}
// We should attempt to update the previous events table with any
@ -741,28 +759,28 @@ func (d *Database) storeEvent(
// any other so this is fine. If we ever update GetLatestEventsForUpdate or NewLatestEventsUpdater
// to do writes however then this will need to go inside `Writer.Do`.
succeeded := false
if updater == nil {
var roomInfo *types.RoomInfo
roomInfo, err = d.roomInfo(ctx, txn, event.RoomID())
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
var roomInfo *types.RoomInfo
roomInfo, err = d.roomInfo(ctx, nil, event.RoomID())
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("d.RoomInfo: %w", err)
}
if roomInfo == nil && len(prevEvents) > 0 {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("expected room %q to exist", event.RoomID())
}
var updater *RoomUpdater
updater, err = d.GetRoomUpdater(ctx, roomInfo)
if err != nil {
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("GetRoomUpdater: %w", err)
}
defer sqlutil.EndTransactionWithCheck(updater, &succeeded, &err)
if err = updater.StorePreviousEvents(eventNID, prevEvents); err != nil {
return 0, 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
return 0, types.StateAtEvent{}, nil, "", fmt.Errorf("updater.StorePreviousEvents: %w", err)
}
succeeded = true
}
return eventNID, roomNID, types.StateAtEvent{
return eventNID, types.StateAtEvent{
BeforeStateSnapshotNID: stateNID,
StateEntry: types.StateEntry{
StateKeyTuple: types.StateKeyTuple{
@ -814,6 +832,10 @@ func (d *Database) MissingAuthPrevEvents(
func (d *Database) assignRoomNID(
ctx context.Context, txn *sql.Tx, roomID string, roomVersion gomatrixserverlib.RoomVersion,
) (types.RoomNID, error) {
roomNID, ok := d.Cache.GetRoomServerRoomNID(roomID)
if ok {
return roomNID, nil
}
// Check if we already have a numeric ID in the database.
roomNID, err := d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
if err == sql.ErrNoRows {
@ -824,12 +846,20 @@ func (d *Database) assignRoomNID(
roomNID, err = d.RoomsTable.SelectRoomNID(ctx, txn, roomID)
}
}
return roomNID, err
if err != nil {
return 0, err
}
d.Cache.StoreRoomServerRoomID(roomNID, roomID)
return roomNID, nil
}
func (d *Database) assignEventTypeNID(
ctx context.Context, txn *sql.Tx, eventType string,
) (types.EventTypeNID, error) {
eventTypeNID, ok := d.Cache.GetEventTypeKey(eventType)
if ok {
return eventTypeNID, nil
}
// Check if we already have a numeric ID in the database.
eventTypeNID, err := d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
if err == sql.ErrNoRows {
@ -840,12 +870,20 @@ func (d *Database) assignEventTypeNID(
eventTypeNID, err = d.EventTypesTable.SelectEventTypeNID(ctx, txn, eventType)
}
}
return eventTypeNID, err
if err != nil {
return 0, err
}
d.Cache.StoreEventTypeKey(eventTypeNID, eventType)
return eventTypeNID, nil
}
func (d *Database) assignStateKeyNID(
ctx context.Context, txn *sql.Tx, eventStateKey string,
) (types.EventStateKeyNID, error) {
eventStateKeyNID, ok := d.Cache.GetEventStateKeyNID(eventStateKey)
if ok {
return eventStateKeyNID, nil
}
// Check if we already have a numeric ID in the database.
eventStateKeyNID, err := d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
if err == sql.ErrNoRows {
@ -856,6 +894,7 @@ func (d *Database) assignStateKeyNID(
eventStateKeyNID, err = d.EventStateKeysTable.SelectEventStateKeyNID(ctx, txn, eventStateKey)
}
}
d.Cache.StoreEventStateKey(eventStateKeyNID, eventStateKey)
return eventStateKeyNID, err
}
@ -899,7 +938,7 @@ func extractRoomVersionFromCreateEvent(event *gomatrixserverlib.Event) (
//
// Returns the redaction event and the event ID of the redacted event if this call resulted in a redaction.
func (d *Database) handleRedactions(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event,
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*gomatrixserverlib.Event, string, error) {
var err error
isRedactionEvent := event.Type() == gomatrixserverlib.MRoomRedaction && event.StateKey() == nil
@ -919,7 +958,7 @@ func (d *Database) handleRedactions(
}
}
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, eventNID, event)
redactionEvent, redactedEvent, validated, err := d.loadRedactionPair(ctx, txn, roomNID, eventNID, event)
if err != nil {
return nil, "", fmt.Errorf("d.loadRedactionPair: %w", err)
}
@ -985,7 +1024,7 @@ func (d *Database) handleRedactions(
// loadRedactionPair returns both the redaction event and the redacted event, else nil.
func (d *Database) loadRedactionPair(
ctx context.Context, txn *sql.Tx, eventNID types.EventNID, event *gomatrixserverlib.Event,
ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventNID types.EventNID, event *gomatrixserverlib.Event,
) (*types.Event, *types.Event, bool, error) {
var redactionEvent, redactedEvent *types.Event
var info *tables.RedactionInfo
@ -1017,9 +1056,9 @@ func (d *Database) loadRedactionPair(
}
if isRedactionEvent {
redactedEvent = d.loadEvent(ctx, info.RedactsEventID)
redactedEvent = d.loadEvent(ctx, roomNID, info.RedactsEventID)
} else {
redactionEvent = d.loadEvent(ctx, info.RedactionEventID)
redactionEvent = d.loadEvent(ctx, roomNID, info.RedactionEventID)
}
return redactionEvent, redactedEvent, info.Validated, nil
@ -1035,7 +1074,7 @@ func (d *Database) applyRedactions(events []types.Event) {
}
// loadEvent loads a single event or returns nil on any problems/missing event
func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
func (d *Database) loadEvent(ctx context.Context, roomNID types.RoomNID, eventID string) *types.Event {
nids, err := d.EventNIDs(ctx, []string{eventID})
if err != nil {
return nil
@ -1043,7 +1082,7 @@ func (d *Database) loadEvent(ctx context.Context, eventID string) *types.Event {
if len(nids) == 0 {
return nil
}
evs, err := d.Events(ctx, []types.EventNID{nids[eventID]})
evs, err := d.Events(ctx, roomNID, []types.EventNID{nids[eventID].EventNID})
if err != nil {
return nil
}
@ -1470,13 +1509,19 @@ func (d *Database) PurgeRoom(ctx context.Context, roomID string) error {
func (d *Database) UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// un-publish old room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil {
return fmt.Errorf("failed to unpublish room: %w", err)
published, err := d.PublishedTable.SelectPublishedFromRoomID(ctx, txn, oldRoomID)
if err != nil {
return fmt.Errorf("failed to get published room: %w", err)
}
// publish new room
if err := d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil {
return fmt.Errorf("failed to publish room: %w", err)
if published {
// un-publish old room
if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, oldRoomID, "", "", false); err != nil {
return fmt.Errorf("failed to unpublish room: %w", err)
}
// publish new room
if err = d.PublishedTable.UpsertRoomPublished(ctx, txn, newRoomID, "", "", true); err != nil {
return fmt.Errorf("failed to publish room: %w", err)
}
}
// Migrate any existing room aliases

View file

@ -3,7 +3,9 @@ package shared_test
import (
"context"
"testing"
"time"
"github.com/matrix-org/dendrite/internal/caching"
"github.com/stretchr/testify/assert"
"github.com/matrix-org/dendrite/internal/sqlutil"
@ -48,11 +50,14 @@ func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Dat
}
assert.NoError(t, err)
cache := caching.NewRistrettoCache(8*1024*1024, time.Hour, false)
return &shared.Database{
DB: db,
EventStateKeysTable: stateKeyTable,
MembershipTable: membershipTable,
Writer: sqlutil.NewExclusiveWriter(),
Cache: cache,
}, func() {
err := base.Close()
assert.NoError(t, err)

View file

@ -110,10 +110,10 @@ const bulkSelectEventIDSQL = "" +
"SELECT event_nid, event_id FROM roomserver_events WHERE event_nid IN ($1)"
const bulkSelectEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE event_id IN ($1)"
"SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE event_id IN ($1)"
const bulkSelectUnsentEventNIDSQL = "" +
"SELECT event_id, event_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
"SELECT event_id, event_nid, room_nid FROM roomserver_events WHERE sent_to_output = 0 AND event_id IN ($1)"
const selectMaxEventDepthSQL = "" +
"SELECT COALESCE(MAX(depth) + 1, 0) FROM roomserver_events WHERE event_nid IN ($1)"
@ -572,20 +572,20 @@ func (s *eventStatements) BulkSelectEventID(ctx context.Context, txn *sql.Tx, ev
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
func (s *eventStatements) BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, false)
}
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID
// only for events that haven't already been sent to the roomserver output.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error) {
func (s *eventStatements) BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error) {
return s.bulkSelectEventNID(ctx, txn, eventIDs, true)
}
// bulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventNID, error) {
func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string, onlyUnsent bool) (map[string]types.EventMetadata, error) {
///////////////
iEventIDs := make([]interface{}, len(eventIDs))
for k, v := range eventIDs {
@ -609,14 +609,18 @@ func (s *eventStatements) bulkSelectEventNID(ctx context.Context, txn *sql.Tx, e
return nil, err
}
defer internal.CloseAndLogIfError(ctx, rows, "bulkSelectEventNID: rows.close() failed")
results := make(map[string]types.EventNID, len(eventIDs))
results := make(map[string]types.EventMetadata, len(eventIDs))
var eventID string
var eventNID int64
var roomNID int64
for rows.Next() {
if err = rows.Scan(&eventID, &eventNID); err != nil {
if err = rows.Scan(&eventID, &eventNID, &roomNID); err != nil {
return nil, err
}
results[eventID] = types.EventNID(eventNID)
results[eventID] = types.EventMetadata{
EventNID: types.EventNID(eventNID),
RoomNID: types.RoomNID(roomNID),
}
}
return results, nil
}

View file

@ -63,8 +63,8 @@ type Events interface {
BulkSelectEventID(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (map[types.EventNID]string, error)
// BulkSelectEventNIDs returns a map from string event ID to numeric event ID.
// If an event ID is not in the database then it is omitted from the map.
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventNID, error)
BulkSelectEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error)
BulkSelectUnsentEventNID(ctx context.Context, txn *sql.Tx, eventIDs []string) (map[string]types.EventMetadata, error)
SelectMaxEventDepth(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (int64, error)
SelectRoomNIDsForEventNIDs(ctx context.Context, txn *sql.Tx, eventNIDs []types.EventNID) (roomNIDs map[types.EventNID]types.RoomNID, err error)
SelectEventRejected(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, eventID string) (rejected bool, err error)