mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Optimize inserting pending PDUs/EDUs (#2821)
This optimizes the association of PDUs/EDUs to their destination by inserting all destinations in one transaction.
This commit is contained in:
parent
e98d75fd63
commit
9e4c3171da
7 changed files with 127 additions and 119 deletions
|
@ -76,40 +76,22 @@ func (oq *destinationQueue) sendEvent(event *gomatrixserverlib.HeaderedEvent, re
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a database entry that associates the given PDU NID with
|
// If there's room in memory to hold the event then add it to the
|
||||||
// this destination queue. We'll then be able to retrieve the PDU
|
// list.
|
||||||
// later.
|
oq.pendingMutex.Lock()
|
||||||
if err := oq.db.AssociatePDUWithDestination(
|
if len(oq.pendingPDUs) < maxPDUsInMemory {
|
||||||
oq.process.Context(),
|
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
|
||||||
"", // TODO: remove this, as we don't need to persist the transaction ID
|
pdu: event,
|
||||||
oq.destination, // the destination server name
|
receipt: receipt,
|
||||||
receipt, // NIDs from federationapi_queue_json table
|
})
|
||||||
); err != nil {
|
|
||||||
logrus.WithError(err).Errorf("failed to associate PDU %q with destination %q", event.EventID(), oq.destination)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Check if the destination is blacklisted. If it isn't then wake
|
|
||||||
// up the queue.
|
|
||||||
if !oq.statistics.Blacklisted() {
|
|
||||||
// If there's room in memory to hold the event then add it to the
|
|
||||||
// list.
|
|
||||||
oq.pendingMutex.Lock()
|
|
||||||
if len(oq.pendingPDUs) < maxPDUsInMemory {
|
|
||||||
oq.pendingPDUs = append(oq.pendingPDUs, &queuedPDU{
|
|
||||||
pdu: event,
|
|
||||||
receipt: receipt,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
oq.overflowed.Store(true)
|
|
||||||
}
|
|
||||||
oq.pendingMutex.Unlock()
|
|
||||||
|
|
||||||
if !oq.backingOff.Load() {
|
|
||||||
oq.wakeQueueAndNotify()
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
oq.overflowed.Store(true)
|
oq.overflowed.Store(true)
|
||||||
}
|
}
|
||||||
|
oq.pendingMutex.Unlock()
|
||||||
|
|
||||||
|
if !oq.backingOff.Load() {
|
||||||
|
oq.wakeQueueAndNotify()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendEDU adds the EDU event to the pending queue for the destination.
|
// sendEDU adds the EDU event to the pending queue for the destination.
|
||||||
|
@ -120,41 +102,23 @@ func (oq *destinationQueue) sendEDU(event *gomatrixserverlib.EDU, receipt *share
|
||||||
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
|
logrus.Errorf("attempt to send nil EDU with destination %q", oq.destination)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Create a database entry that associates the given PDU NID with
|
|
||||||
// this destination queue. We'll then be able to retrieve the PDU
|
|
||||||
// later.
|
|
||||||
if err := oq.db.AssociateEDUWithDestination(
|
|
||||||
oq.process.Context(),
|
|
||||||
oq.destination, // the destination server name
|
|
||||||
receipt, // NIDs from federationapi_queue_json table
|
|
||||||
event.Type,
|
|
||||||
nil, // this will use the default expireEDUTypes map
|
|
||||||
); err != nil {
|
|
||||||
logrus.WithError(err).Errorf("failed to associate EDU with destination %q", oq.destination)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Check if the destination is blacklisted. If it isn't then wake
|
|
||||||
// up the queue.
|
|
||||||
if !oq.statistics.Blacklisted() {
|
|
||||||
// If there's room in memory to hold the event then add it to the
|
|
||||||
// list.
|
|
||||||
oq.pendingMutex.Lock()
|
|
||||||
if len(oq.pendingEDUs) < maxEDUsInMemory {
|
|
||||||
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
|
|
||||||
edu: event,
|
|
||||||
receipt: receipt,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
oq.overflowed.Store(true)
|
|
||||||
}
|
|
||||||
oq.pendingMutex.Unlock()
|
|
||||||
|
|
||||||
if !oq.backingOff.Load() {
|
// If there's room in memory to hold the event then add it to the
|
||||||
oq.wakeQueueAndNotify()
|
// list.
|
||||||
}
|
oq.pendingMutex.Lock()
|
||||||
|
if len(oq.pendingEDUs) < maxEDUsInMemory {
|
||||||
|
oq.pendingEDUs = append(oq.pendingEDUs, &queuedEDU{
|
||||||
|
edu: event,
|
||||||
|
receipt: receipt,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
oq.overflowed.Store(true)
|
oq.overflowed.Store(true)
|
||||||
}
|
}
|
||||||
|
oq.pendingMutex.Unlock()
|
||||||
|
|
||||||
|
if !oq.backingOff.Load() {
|
||||||
|
oq.wakeQueueAndNotify()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleBackoffNotifier is registered as the backoff notification
|
// handleBackoffNotifier is registered as the backoff notification
|
||||||
|
|
|
@ -24,6 +24,7 @@ import (
|
||||||
"github.com/getsentry/sentry-go"
|
"github.com/getsentry/sentry-go"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/prometheus/client_golang/prometheus"
|
"github.com/prometheus/client_golang/prometheus"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
@ -247,11 +248,25 @@ func (oqs *OutgoingQueues) SendEvent(
|
||||||
}
|
}
|
||||||
|
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
if queue := oqs.getQueue(destination); queue != nil {
|
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
|
||||||
queue.sendEvent(ev, nid)
|
queue.sendEvent(ev, nid)
|
||||||
|
} else {
|
||||||
|
delete(destmap, destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a database entry that associates the given PDU NID with
|
||||||
|
// this destinations queue. We'll then be able to retrieve the PDU
|
||||||
|
// later.
|
||||||
|
if err := oqs.db.AssociatePDUWithDestinations(
|
||||||
|
oqs.process.Context(),
|
||||||
|
destmap,
|
||||||
|
nid, // NIDs from federationapi_queue_json table
|
||||||
|
); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("failed to associate PDUs %q with destinations", nid)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -321,11 +336,27 @@ func (oqs *OutgoingQueues) SendEDU(
|
||||||
}
|
}
|
||||||
|
|
||||||
for destination := range destmap {
|
for destination := range destmap {
|
||||||
if queue := oqs.getQueue(destination); queue != nil {
|
if queue := oqs.getQueue(destination); queue != nil && !queue.statistics.Blacklisted() {
|
||||||
queue.sendEDU(e, nid)
|
queue.sendEDU(e, nid)
|
||||||
|
} else {
|
||||||
|
delete(destmap, destination)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a database entry that associates the given PDU NID with
|
||||||
|
// this destination queue. We'll then be able to retrieve the PDU
|
||||||
|
// later.
|
||||||
|
if err := oqs.db.AssociateEDUWithDestinations(
|
||||||
|
oqs.process.Context(),
|
||||||
|
destmap, // the destination server name
|
||||||
|
nid, // NIDs from federationapi_queue_json table
|
||||||
|
e.Type,
|
||||||
|
nil, // this will use the default expireEDUTypes map
|
||||||
|
); err != nil {
|
||||||
|
logrus.WithError(err).Errorf("failed to associate EDU with destinations")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,6 +25,10 @@ import (
|
||||||
"go.uber.org/atomic"
|
"go.uber.org/atomic"
|
||||||
"gotest.tools/v3/poll"
|
"gotest.tools/v3/poll"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/api"
|
"github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/federationapi/statistics"
|
"github.com/matrix-org/dendrite/federationapi/statistics"
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage"
|
"github.com/matrix-org/dendrite/federationapi/storage"
|
||||||
|
@ -34,9 +38,6 @@ import (
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) {
|
func mustCreateFederationDatabase(t *testing.T, dbType test.DBType, realDatabase bool) (storage.Database, *process.ProcessContext, func()) {
|
||||||
|
@ -158,30 +159,36 @@ func (d *fakeDatabase) GetPendingEDUs(ctx context.Context, serverName gomatrixse
|
||||||
return edus, nil
|
return edus, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDatabase) AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error {
|
func (d *fakeDatabase) AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error {
|
||||||
d.dbMutex.Lock()
|
d.dbMutex.Lock()
|
||||||
defer d.dbMutex.Unlock()
|
defer d.dbMutex.Unlock()
|
||||||
|
|
||||||
if _, ok := d.pendingPDUs[receipt]; ok {
|
if _, ok := d.pendingPDUs[receipt]; ok {
|
||||||
if _, ok := d.associatedPDUs[serverName]; !ok {
|
for destination := range destinations {
|
||||||
d.associatedPDUs[serverName] = make(map[*shared.Receipt]struct{})
|
if _, ok := d.associatedPDUs[destination]; !ok {
|
||||||
|
d.associatedPDUs[destination] = make(map[*shared.Receipt]struct{})
|
||||||
|
}
|
||||||
|
d.associatedPDUs[destination][receipt] = struct{}{}
|
||||||
}
|
}
|
||||||
d.associatedPDUs[serverName][receipt] = struct{}{}
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return errors.New("PDU doesn't exist")
|
return errors.New("PDU doesn't exist")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *fakeDatabase) AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
|
func (d *fakeDatabase) AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error {
|
||||||
d.dbMutex.Lock()
|
d.dbMutex.Lock()
|
||||||
defer d.dbMutex.Unlock()
|
defer d.dbMutex.Unlock()
|
||||||
|
|
||||||
if _, ok := d.pendingEDUs[receipt]; ok {
|
if _, ok := d.pendingEDUs[receipt]; ok {
|
||||||
if _, ok := d.associatedEDUs[serverName]; !ok {
|
for destination := range destinations {
|
||||||
d.associatedEDUs[serverName] = make(map[*shared.Receipt]struct{})
|
if _, ok := d.associatedEDUs[destination]; !ok {
|
||||||
|
d.associatedEDUs[destination] = make(map[*shared.Receipt]struct{})
|
||||||
|
}
|
||||||
|
d.associatedEDUs[destination][receipt] = struct{}{}
|
||||||
}
|
}
|
||||||
d.associatedEDUs[serverName][receipt] = struct{}{}
|
|
||||||
return nil
|
return nil
|
||||||
} else {
|
} else {
|
||||||
return errors.New("EDU doesn't exist")
|
return errors.New("EDU doesn't exist")
|
||||||
|
@ -821,15 +828,15 @@ func TestSendPDUBatches(t *testing.T) {
|
||||||
<-pc.WaitForShutdown()
|
<-pc.WaitForShutdown()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
|
||||||
// Populate database with > maxPDUsPerTransaction
|
// Populate database with > maxPDUsPerTransaction
|
||||||
pduMultiplier := uint32(3)
|
pduMultiplier := uint32(3)
|
||||||
for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ {
|
for i := 0; i < maxPDUsPerTransaction*int(pduMultiplier); i++ {
|
||||||
ev := mustCreatePDU(t)
|
ev := mustCreatePDU(t)
|
||||||
headeredJSON, _ := json.Marshal(ev)
|
headeredJSON, _ := json.Marshal(ev)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
||||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid)
|
||||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
|
assert.NoError(t, err, "failed to associate PDU with destinations")
|
||||||
db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := mustCreatePDU(t)
|
ev := mustCreatePDU(t)
|
||||||
|
@ -865,13 +872,15 @@ func TestSendEDUBatches(t *testing.T) {
|
||||||
<-pc.WaitForShutdown()
|
<-pc.WaitForShutdown()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
|
||||||
// Populate database with > maxEDUsPerTransaction
|
// Populate database with > maxEDUsPerTransaction
|
||||||
eduMultiplier := uint32(3)
|
eduMultiplier := uint32(3)
|
||||||
for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ {
|
for i := 0; i < maxEDUsPerTransaction*int(eduMultiplier); i++ {
|
||||||
ev := mustCreateEDU(t)
|
ev := mustCreateEDU(t)
|
||||||
ephemeralJSON, _ := json.Marshal(ev)
|
ephemeralJSON, _ := json.Marshal(ev)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
||||||
db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
|
err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil)
|
||||||
|
assert.NoError(t, err, "failed to associate EDU with destinations")
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := mustCreateEDU(t)
|
ev := mustCreateEDU(t)
|
||||||
|
@ -907,23 +916,23 @@ func TestSendPDUAndEDUBatches(t *testing.T) {
|
||||||
<-pc.WaitForShutdown()
|
<-pc.WaitForShutdown()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
|
||||||
// Populate database with > maxEDUsPerTransaction
|
// Populate database with > maxEDUsPerTransaction
|
||||||
multiplier := uint32(3)
|
multiplier := uint32(3)
|
||||||
|
|
||||||
for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ {
|
for i := 0; i < maxPDUsPerTransaction*int(multiplier)+1; i++ {
|
||||||
ev := mustCreatePDU(t)
|
ev := mustCreatePDU(t)
|
||||||
headeredJSON, _ := json.Marshal(ev)
|
headeredJSON, _ := json.Marshal(ev)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
||||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid)
|
||||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, i))
|
assert.NoError(t, err, "failed to associate PDU with destinations")
|
||||||
db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ {
|
for i := 0; i < maxEDUsPerTransaction*int(multiplier); i++ {
|
||||||
ev := mustCreateEDU(t)
|
ev := mustCreateEDU(t)
|
||||||
ephemeralJSON, _ := json.Marshal(ev)
|
ephemeralJSON, _ := json.Marshal(ev)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
||||||
db.AssociateEDUWithDestination(pc.Context(), destination, nid, ev.Type, nil)
|
err := db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, ev.Type, nil)
|
||||||
|
assert.NoError(t, err, "failed to associate EDU with destinations")
|
||||||
}
|
}
|
||||||
|
|
||||||
ev := mustCreateEDU(t)
|
ev := mustCreateEDU(t)
|
||||||
|
@ -960,13 +969,12 @@ func TestExternalFailureBackoffDoesntStartQueue(t *testing.T) {
|
||||||
|
|
||||||
dest := queues.getQueue(destination)
|
dest := queues.getQueue(destination)
|
||||||
queues.statistics.ForServer(destination).Failure()
|
queues.statistics.ForServer(destination).Failure()
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
|
||||||
ev := mustCreatePDU(t)
|
ev := mustCreatePDU(t)
|
||||||
headeredJSON, _ := json.Marshal(ev)
|
headeredJSON, _ := json.Marshal(ev)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(headeredJSON))
|
||||||
now := gomatrixserverlib.AsTimestamp(time.Now())
|
err := db.AssociatePDUWithDestinations(pc.Context(), destinations, nid)
|
||||||
transactionID := gomatrixserverlib.TransactionID(fmt.Sprintf("%d-%d", now, 1))
|
assert.NoError(t, err, "failed to associate PDU with destinations")
|
||||||
db.AssociatePDUWithDestination(pc.Context(), transactionID, destination, nid)
|
|
||||||
|
|
||||||
pollEnd := time.Now().Add(3 * time.Second)
|
pollEnd := time.Now().Add(3 * time.Second)
|
||||||
runningCheck := func(log poll.LogT) poll.Result {
|
runningCheck := func(log poll.LogT) poll.Result {
|
||||||
|
@ -988,6 +996,7 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
failuresUntilBlacklist := uint32(1)
|
failuresUntilBlacklist := uint32(1)
|
||||||
destination := gomatrixserverlib.ServerName("remotehost")
|
destination := gomatrixserverlib.ServerName("remotehost")
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{destination: {}}
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
|
db, fc, queues, pc, close := testSetup(failuresUntilBlacklist, false, t, dbType, true)
|
||||||
// NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
|
// NOTE : These defers aren't called if go test is killed so the dbs may not get cleaned up.
|
||||||
|
@ -1009,7 +1018,8 @@ func TestQueueInteractsWithRealDatabasePDUAndEDU(t *testing.T) {
|
||||||
edu := mustCreateEDU(t)
|
edu := mustCreateEDU(t)
|
||||||
ephemeralJSON, _ := json.Marshal(edu)
|
ephemeralJSON, _ := json.Marshal(edu)
|
||||||
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
nid, _ := db.StoreJSON(pc.Context(), string(ephemeralJSON))
|
||||||
db.AssociateEDUWithDestination(pc.Context(), destination, nid, edu.Type, nil)
|
err = db.AssociateEDUWithDestinations(pc.Context(), destinations, nid, edu.Type, nil)
|
||||||
|
assert.NoError(t, err, "failed to associate EDU with destinations")
|
||||||
|
|
||||||
checkBlacklisted := func(log poll.LogT) poll.Result {
|
checkBlacklisted := func(log poll.LogT) poll.Result {
|
||||||
if fc.txCount.Load() == failuresUntilBlacklist {
|
if fc.txCount.Load() == failuresUntilBlacklist {
|
||||||
|
|
|
@ -18,9 +18,10 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
"github.com/matrix-org/dendrite/federationapi/storage/shared"
|
||||||
"github.com/matrix-org/dendrite/federationapi/types"
|
"github.com/matrix-org/dendrite/federationapi/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
|
@ -38,8 +39,8 @@ type Database interface {
|
||||||
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
|
GetPendingPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (pdus map[*shared.Receipt]*gomatrixserverlib.HeaderedEvent, err error)
|
||||||
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
|
GetPendingEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, limit int) (edus map[*shared.Receipt]*gomatrixserverlib.EDU, err error)
|
||||||
|
|
||||||
AssociatePDUWithDestination(ctx context.Context, transactionID gomatrixserverlib.TransactionID, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt) error
|
AssociatePDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt) error
|
||||||
AssociateEDUWithDestination(ctx context.Context, serverName gomatrixserverlib.ServerName, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
|
AssociateEDUWithDestinations(ctx context.Context, destinations map[gomatrixserverlib.ServerName]struct{}, receipt *shared.Receipt, eduType string, expireEDUTypes map[string]time.Duration) error
|
||||||
|
|
||||||
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanPDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
CleanEDUs(ctx context.Context, serverName gomatrixserverlib.ServerName, receipts []*shared.Receipt) error
|
||||||
|
|
|
@ -38,9 +38,9 @@ var defaultExpireEDUTypes = map[string]time.Duration{
|
||||||
// AssociateEDUWithDestination creates an association that the
|
// AssociateEDUWithDestination creates an association that the
|
||||||
// destination queues will use to determine which JSON blobs to send
|
// destination queues will use to determine which JSON blobs to send
|
||||||
// to which servers.
|
// to which servers.
|
||||||
func (d *Database) AssociateEDUWithDestination(
|
func (d *Database) AssociateEDUWithDestinations(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
serverName gomatrixserverlib.ServerName,
|
destinations map[gomatrixserverlib.ServerName]struct{},
|
||||||
receipt *Receipt,
|
receipt *Receipt,
|
||||||
eduType string,
|
eduType string,
|
||||||
expireEDUTypes map[string]time.Duration,
|
expireEDUTypes map[string]time.Duration,
|
||||||
|
@ -59,17 +59,18 @@ func (d *Database) AssociateEDUWithDestination(
|
||||||
expiresAt = 0
|
expiresAt = 0
|
||||||
}
|
}
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.FederationQueueEDUs.InsertQueueEDU(
|
var err error
|
||||||
ctx, // context
|
for destination := range destinations {
|
||||||
txn, // SQL transaction
|
err = d.FederationQueueEDUs.InsertQueueEDU(
|
||||||
eduType, // EDU type for coalescing
|
ctx, // context
|
||||||
serverName, // destination server name
|
txn, // SQL transaction
|
||||||
receipt.nid, // NID from the federationapi_queue_json table
|
eduType, // EDU type for coalescing
|
||||||
expiresAt, // The timestamp this EDU will expire
|
destination, // destination server name
|
||||||
); err != nil {
|
receipt.nid, // NID from the federationapi_queue_json table
|
||||||
return fmt.Errorf("InsertQueueEDU: %w", err)
|
expiresAt, // The timestamp this EDU will expire
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return nil
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -27,23 +27,23 @@ import (
|
||||||
// AssociatePDUWithDestination creates an association that the
|
// AssociatePDUWithDestination creates an association that the
|
||||||
// destination queues will use to determine which JSON blobs to send
|
// destination queues will use to determine which JSON blobs to send
|
||||||
// to which servers.
|
// to which servers.
|
||||||
func (d *Database) AssociatePDUWithDestination(
|
func (d *Database) AssociatePDUWithDestinations(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
transactionID gomatrixserverlib.TransactionID,
|
destinations map[gomatrixserverlib.ServerName]struct{},
|
||||||
serverName gomatrixserverlib.ServerName,
|
|
||||||
receipt *Receipt,
|
receipt *Receipt,
|
||||||
) error {
|
) error {
|
||||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
if err := d.FederationQueuePDUs.InsertQueuePDU(
|
var err error
|
||||||
ctx, // context
|
for destination := range destinations {
|
||||||
txn, // SQL transaction
|
err = d.FederationQueuePDUs.InsertQueuePDU(
|
||||||
transactionID, // transaction ID
|
ctx, // context
|
||||||
serverName, // destination server name
|
txn, // SQL transaction
|
||||||
receipt.nid, // NID from the federationapi_queue_json table
|
"", // transaction ID
|
||||||
); err != nil {
|
destination, // destination server name
|
||||||
return fmt.Errorf("InsertQueuePDU: %w", err)
|
receipt.nid, // NID from the federationapi_queue_json table
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return nil
|
return err
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -35,6 +35,7 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
destinations := map[gomatrixserverlib.ServerName]struct{}{"localhost": {}}
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
db, close := mustCreateFederationDatabase(t, dbType)
|
db, close := mustCreateFederationDatabase(t, dbType)
|
||||||
defer close()
|
defer close()
|
||||||
|
@ -43,7 +44,7 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
receipt, err := db.StoreJSON(ctx, "{}")
|
receipt, err := db.StoreJSON(ctx, "{}")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
|
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MReceipt, expireEDUTypes)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
}
|
}
|
||||||
// add data without expiry
|
// add data without expiry
|
||||||
|
@ -51,7 +52,7 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
|
// m.read_marker gets the default expiry of 24h, so won't be deleted further down in this test
|
||||||
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, "m.read_marker", expireEDUTypes)
|
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, "m.read_marker", expireEDUTypes)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
// Delete expired EDUs
|
// Delete expired EDUs
|
||||||
|
@ -67,7 +68,7 @@ func TestExpireEDUs(t *testing.T) {
|
||||||
receipt, err = db.StoreJSON(ctx, "{}")
|
receipt, err = db.StoreJSON(ctx, "{}")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = db.AssociateEDUWithDestination(ctx, "localhost", receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
|
err = db.AssociateEDUWithDestinations(ctx, destinations, receipt, gomatrixserverlib.MDirectToDevice, expireEDUTypes)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
|
|
||||||
err = db.DeleteExpiredEDUs(ctx)
|
err = db.DeleteExpiredEDUs(ctx)
|
||||||
|
|
Loading…
Reference in a new issue