mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 13:52:46 +00:00
Event relations (#2790)
This adds support for tracking `m.relates_to`, as well as adding support for the various `/room/{roomID}/relations/...` endpoints to the CS API.
This commit is contained in:
parent
3c1474f68f
commit
23a3e04579
19 changed files with 943 additions and 51 deletions
|
@ -53,6 +53,7 @@ type Database struct {
|
|||
NotificationData tables.NotificationData
|
||||
Ignores tables.Ignores
|
||||
Presence tables.Presence
|
||||
Relations tables.Relations
|
||||
}
|
||||
|
||||
func (d *Database) NewDatabaseSnapshot(ctx context.Context) (*DatabaseTransaction, error) {
|
||||
|
@ -579,10 +580,40 @@ func (d *Database) SelectMembershipForUser(ctx context.Context, roomID, userID s
|
|||
return d.Memberships.SelectMembershipForUser(ctx, nil, roomID, userID, pos)
|
||||
}
|
||||
|
||||
func (s *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
|
||||
return s.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
|
||||
func (d *Database) ReIndex(ctx context.Context, limit, afterID int64) (map[int64]gomatrixserverlib.HeaderedEvent, error) {
|
||||
return d.OutputEvents.ReIndex(ctx, nil, limit, afterID, []string{
|
||||
gomatrixserverlib.MRoomName,
|
||||
gomatrixserverlib.MRoomTopic,
|
||||
"m.room.message",
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) UpdateRelations(ctx context.Context, event *gomatrixserverlib.HeaderedEvent) error {
|
||||
var content gomatrixserverlib.RelationContent
|
||||
if err := json.Unmarshal(event.Content(), &content); err != nil {
|
||||
return fmt.Errorf("json.Unmarshal: %w", err)
|
||||
}
|
||||
switch {
|
||||
case content.Relations == nil:
|
||||
return nil
|
||||
case content.Relations.EventID == "":
|
||||
return nil
|
||||
case content.Relations.RelationType == "":
|
||||
return nil
|
||||
case event.Type() == gomatrixserverlib.MRoomRedaction:
|
||||
return nil
|
||||
default:
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Relations.InsertRelation(
|
||||
ctx, txn, event.RoomID(), content.Relations.EventID,
|
||||
event.EventID(), event.Type(), content.Relations.RelationType,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Database) RedactRelations(ctx context.Context, roomID, redactedEventID string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Relations.DeleteRelation(ctx, txn, roomID, redactedEventID)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -589,3 +589,84 @@ func (d *DatabaseTransaction) PresenceAfter(ctx context.Context, after types.Str
|
|||
func (d *DatabaseTransaction) MaxStreamPositionForPresence(ctx context.Context) (types.StreamPosition, error) {
|
||||
return d.Presence.GetMaxPresenceID(ctx, d.txn)
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) MaxStreamPositionForRelations(ctx context.Context) (types.StreamPosition, error) {
|
||||
id, err := d.Relations.SelectMaxRelationID(ctx, d.txn)
|
||||
return types.StreamPosition(id), err
|
||||
}
|
||||
|
||||
func (d *DatabaseTransaction) RelationsFor(ctx context.Context, roomID, eventID, relType, eventType string, from, to types.StreamPosition, backwards bool, limit int) (
|
||||
events []types.StreamEvent, prevBatch, nextBatch string, err error,
|
||||
) {
|
||||
r := types.Range{
|
||||
From: from,
|
||||
To: to,
|
||||
Backwards: backwards,
|
||||
}
|
||||
|
||||
if r.Backwards && r.From == 0 {
|
||||
// If we're working backwards (dir=b) and there's no ?from= specified then
|
||||
// we will automatically want to work backwards from the current position,
|
||||
// so find out what that is.
|
||||
if r.From, err = d.MaxStreamPositionForRelations(ctx); err != nil {
|
||||
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
|
||||
}
|
||||
// The result normally isn't inclusive of the event *at* the ?from=
|
||||
// position, so add 1 here so that we include the most recent relation.
|
||||
r.From++
|
||||
} else if !r.Backwards && r.To == 0 {
|
||||
// If we're working forwards (dir=f) and there's no ?to= specified then
|
||||
// we will automatically want to work forwards towards the current position,
|
||||
// so find out what that is.
|
||||
if r.To, err = d.MaxStreamPositionForRelations(ctx); err != nil {
|
||||
return nil, "", "", fmt.Errorf("d.MaxStreamPositionForRelations: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// First look up any relations from the database. We add one to the limit here
|
||||
// so that we can tell if we're overflowing, as we will only set the "next_batch"
|
||||
// in the response if we are.
|
||||
relations, _, err := d.Relations.SelectRelationsInRange(ctx, d.txn, roomID, eventID, relType, eventType, r, limit+1)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("d.Relations.SelectRelationsInRange: %w", err)
|
||||
}
|
||||
|
||||
// If we specified a relation type then just get those results, otherwise collate
|
||||
// them from all of the returned relation types.
|
||||
entries := []types.RelationEntry{}
|
||||
if relType != "" {
|
||||
entries = relations[relType]
|
||||
} else {
|
||||
for _, e := range relations {
|
||||
entries = append(entries, e...)
|
||||
}
|
||||
}
|
||||
|
||||
// If there were no entries returned, there were no relations, so stop at this point.
|
||||
if len(entries) == 0 {
|
||||
return nil, "", "", nil
|
||||
}
|
||||
|
||||
// Otherwise, let's try and work out what sensible prev_batch and next_batch values
|
||||
// could be. We've requested an extra event by adding one to the limit already so
|
||||
// that we can determine whether or not to provide a "next_batch", so trim off that
|
||||
// event off the end if needs be.
|
||||
if len(entries) > limit {
|
||||
entries = entries[:len(entries)-1]
|
||||
nextBatch = fmt.Sprintf("%d", entries[len(entries)-1].Position)
|
||||
}
|
||||
// TODO: set prevBatch? doesn't seem to affect the tests...
|
||||
|
||||
// Extract all of the event IDs from the relation entries so that we can pull the
|
||||
// events out of the database. Then go and fetch the events.
|
||||
eventIDs := make([]string, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
eventIDs = append(eventIDs, entry.EventID)
|
||||
}
|
||||
events, err = d.OutputEvents.SelectEvents(ctx, d.txn, eventIDs, nil, true)
|
||||
if err != nil {
|
||||
return nil, "", "", fmt.Errorf("d.OutputEvents.SelectEvents: %w", err)
|
||||
}
|
||||
|
||||
return events, prevBatch, nextBatch, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue