mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Cleanup stale device lists for users we don't share a room with anymore (#2857)
The stale device lists table might contain entries for users we don't share a room with anymore. This now asks the roomserver about left users and removes those entries from the table. Co-authored-by: Neil Alexander <neilalexander@users.noreply.github.com>
This commit is contained in:
parent
aaf4e5c865
commit
7d2344049d
31 changed files with 666 additions and 40 deletions
|
@ -180,14 +180,14 @@ func startup() {
|
||||||
base := base.NewBaseDendrite(cfg, "Monolith")
|
base := base.NewBaseDendrite(cfg, "Monolith")
|
||||||
defer base.Close() // nolint: errcheck
|
defer base.Close() // nolint: errcheck
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
|
||||||
federation := conn.CreateFederationClient(base, pSessions)
|
federation := conn.CreateFederationClient(base, pSessions)
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
|
||||||
|
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
rsAPI := roomserver.NewInternalAPI(base)
|
|
||||||
|
|
||||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
|
|
|
@ -350,7 +350,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
base, federation, rsAPI, base.Caches, keyRing, true,
|
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||||
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
m.userAPI = userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(m.userAPI)
|
keyAPI.SetUserAPI(m.userAPI)
|
||||||
|
|
||||||
|
|
|
@ -165,7 +165,7 @@ func (m *DendriteMonolith) Start() {
|
||||||
base, federation, rsAPI, base.Caches, keyRing, true,
|
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsAPI)
|
||||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, cfg.Derived.ApplicationServices, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
|
|
|
@ -213,7 +213,7 @@ func main() {
|
||||||
base, federation, rsAPI, base.Caches, keyRing, true,
|
base, federation, rsAPI, base.Caches, keyRing, true,
|
||||||
)
|
)
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsComponent)
|
||||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
keyAPI.SetUserAPI(userAPI)
|
keyAPI.SetUserAPI(userAPI)
|
||||||
|
|
||||||
|
|
|
@ -157,11 +157,12 @@ func main() {
|
||||||
serverKeyAPI := &signing.YggdrasilKeys{}
|
serverKeyAPI := &signing.YggdrasilKeys{}
|
||||||
keyRing := serverKeyAPI.KeyRing()
|
keyRing := serverKeyAPI.KeyRing()
|
||||||
|
|
||||||
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation)
|
|
||||||
|
|
||||||
rsComponent := roomserver.NewInternalAPI(
|
rsComponent := roomserver.NewInternalAPI(
|
||||||
base,
|
base,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
keyAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, federation, rsComponent)
|
||||||
|
|
||||||
rsAPI := rsComponent
|
rsAPI := rsComponent
|
||||||
|
|
||||||
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
userAPI := userapi.NewInternalAPI(base, &cfg.UserAPI, nil, keyAPI, rsAPI, base.PushGatewayHTTPClient())
|
||||||
|
|
|
@ -95,7 +95,7 @@ func main() {
|
||||||
}
|
}
|
||||||
keyRing := fsAPI.KeyRing()
|
keyRing := fsAPI.KeyRing()
|
||||||
|
|
||||||
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
keyImpl := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||||
keyAPI := keyImpl
|
keyAPI := keyImpl
|
||||||
if base.UseHTTPAPIs {
|
if base.UseHTTPAPIs {
|
||||||
keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics)
|
keyserver.AddInternalRoutes(base.InternalAPIMux, keyAPI, base.EnableMetrics)
|
||||||
|
|
|
@ -22,7 +22,8 @@ import (
|
||||||
|
|
||||||
func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
func KeyServer(base *basepkg.BaseDendrite, cfg *config.Dendrite) {
|
||||||
fsAPI := base.FederationAPIHTTPClient()
|
fsAPI := base.FederationAPIHTTPClient()
|
||||||
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI)
|
rsAPI := base.RoomserverHTTPClient()
|
||||||
|
intAPI := keyserver.NewInternalAPI(base, &base.Cfg.KeyServer, fsAPI, rsAPI)
|
||||||
intAPI.SetUserAPI(base.UserAPIClient())
|
intAPI.SetUserAPI(base.UserAPIClient())
|
||||||
|
|
||||||
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics)
|
keyserver.AddInternalRoutes(base.InternalAPIMux, intAPI, base.EnableMetrics)
|
||||||
|
|
|
@ -24,6 +24,8 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
|
||||||
"github.com/matrix-org/gomatrix"
|
"github.com/matrix-org/gomatrix"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
"github.com/matrix-org/util"
|
"github.com/matrix-org/util"
|
||||||
|
@ -102,6 +104,7 @@ type DeviceListUpdater struct {
|
||||||
// block on or timeout via a select.
|
// block on or timeout via a select.
|
||||||
userIDToChan map[string]chan bool
|
userIDToChan map[string]chan bool
|
||||||
userIDToChanMu *sync.Mutex
|
userIDToChanMu *sync.Mutex
|
||||||
|
rsAPI rsapi.KeyserverRoomserverAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
// DeviceListUpdaterDatabase is the subset of functionality from storage.Database required for the updater.
|
||||||
|
@ -124,6 +127,8 @@ type DeviceListUpdaterDatabase interface {
|
||||||
|
|
||||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||||
|
|
||||||
|
DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type DeviceListUpdaterAPI interface {
|
type DeviceListUpdaterAPI interface {
|
||||||
|
@ -140,7 +145,7 @@ func NewDeviceListUpdater(
|
||||||
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
process *process.ProcessContext, db DeviceListUpdaterDatabase,
|
||||||
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
api DeviceListUpdaterAPI, producer KeyChangeProducer,
|
||||||
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
fedClient fedsenderapi.KeyserverFederationAPI, numWorkers int,
|
||||||
thisServer gomatrixserverlib.ServerName,
|
rsAPI rsapi.KeyserverRoomserverAPI, thisServer gomatrixserverlib.ServerName,
|
||||||
) *DeviceListUpdater {
|
) *DeviceListUpdater {
|
||||||
return &DeviceListUpdater{
|
return &DeviceListUpdater{
|
||||||
process: process,
|
process: process,
|
||||||
|
@ -154,6 +159,7 @@ func NewDeviceListUpdater(
|
||||||
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
workerChans: make([]chan gomatrixserverlib.ServerName, numWorkers),
|
||||||
userIDToChan: make(map[string]chan bool),
|
userIDToChan: make(map[string]chan bool),
|
||||||
userIDToChanMu: &sync.Mutex{},
|
userIDToChanMu: &sync.Mutex{},
|
||||||
|
rsAPI: rsAPI,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -168,7 +174,7 @@ func (u *DeviceListUpdater) Start() error {
|
||||||
go u.worker(ch)
|
go u.worker(ch)
|
||||||
}
|
}
|
||||||
|
|
||||||
staleLists, err := u.db.StaleDeviceLists(context.Background(), []gomatrixserverlib.ServerName{})
|
staleLists, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -186,6 +192,25 @@ func (u *DeviceListUpdater) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CleanUp removes stale device entries for users we don't share a room with anymore
|
||||||
|
func (u *DeviceListUpdater) CleanUp() error {
|
||||||
|
staleUsers, err := u.db.StaleDeviceLists(u.process.Context(), []gomatrixserverlib.ServerName{})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
res := rsapi.QueryLeftUsersResponse{}
|
||||||
|
if err = u.rsAPI.QueryLeftUsers(u.process.Context(), &rsapi.QueryLeftUsersRequest{StaleDeviceListUsers: staleUsers}, &res); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.LeftUsers) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
logrus.Debugf("Deleting %d stale device list entries", len(res.LeftUsers))
|
||||||
|
return u.db.DeleteStaleDeviceLists(u.process.Context(), res.LeftUsers)
|
||||||
|
}
|
||||||
|
|
||||||
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
func (u *DeviceListUpdater) mutex(userID string) *sync.Mutex {
|
||||||
u.mu.Lock()
|
u.mu.Lock()
|
||||||
defer u.mu.Unlock()
|
defer u.mu.Unlock()
|
||||||
|
|
|
@ -30,7 +30,12 @@ import (
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage"
|
||||||
|
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
"github.com/matrix-org/dendrite/setup/process"
|
"github.com/matrix-org/dendrite/setup/process"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -53,6 +58,10 @@ type mockDeviceListUpdaterDatabase struct {
|
||||||
mu sync.Mutex // protect staleUsers
|
mu sync.Mutex // protect staleUsers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *mockDeviceListUpdaterDatabase) DeleteStaleDeviceLists(ctx context.Context, userIDs []string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
|
||||||
// If no domains are given, all user IDs with stale device lists are returned.
|
// If no domains are given, all user IDs with stale device lists are returned.
|
||||||
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
func (d *mockDeviceListUpdaterDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
|
||||||
|
@ -153,7 +162,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
||||||
}
|
}
|
||||||
ap := &mockDeviceListUpdaterAPI{}
|
ap := &mockDeviceListUpdaterAPI{}
|
||||||
producer := &mockKeyChangeProducer{}
|
producer := &mockKeyChangeProducer{}
|
||||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, "localhost")
|
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, nil, 1, nil, "localhost")
|
||||||
event := gomatrixserverlib.DeviceListUpdateEvent{
|
event := gomatrixserverlib.DeviceListUpdateEvent{
|
||||||
DeviceDisplayName: "Foo Bar",
|
DeviceDisplayName: "Foo Bar",
|
||||||
Deleted: false,
|
Deleted: false,
|
||||||
|
@ -225,7 +234,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
`)),
|
`)),
|
||||||
}, nil
|
}, nil
|
||||||
})
|
})
|
||||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, "example.test")
|
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 2, nil, "example.test")
|
||||||
if err := updater.Start(); err != nil {
|
if err := updater.Start(); err != nil {
|
||||||
t.Fatalf("failed to start updater: %s", err)
|
t.Fatalf("failed to start updater: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -239,6 +248,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
||||||
UserID: remoteUserID,
|
UserID: remoteUserID,
|
||||||
}
|
}
|
||||||
err := updater.Update(ctx, event)
|
err := updater.Update(ctx, event)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Update returned an error: %s", err)
|
t.Fatalf("Update returned an error: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -294,7 +304,7 @@ func TestDebounce(t *testing.T) {
|
||||||
close(incomingFedReq)
|
close(incomingFedReq)
|
||||||
return <-fedCh, nil
|
return <-fedCh, nil
|
||||||
})
|
})
|
||||||
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, "localhost")
|
updater := NewDeviceListUpdater(process.NewProcessContext(), db, ap, producer, fedClient, 1, nil, "localhost")
|
||||||
if err := updater.Start(); err != nil {
|
if err := updater.Start(); err != nil {
|
||||||
t.Fatalf("failed to start updater: %s", err)
|
t.Fatalf("failed to start updater: %s", err)
|
||||||
}
|
}
|
||||||
|
@ -349,3 +359,73 @@ func TestDebounce(t *testing.T) {
|
||||||
t.Errorf("user %s is marked as stale", userID)
|
t.Errorf("user %s is marked as stale", userID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mustCreateKeyserverDB(t *testing.T, dbType test.DBType) (storage.Database, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
base, _, _ := testrig.Base(nil)
|
||||||
|
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := storage.NewDatabase(base, &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return db, clearDB
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockKeyserverRoomserverAPI struct {
|
||||||
|
leftUsers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||||
|
res.LeftUsers = m.leftUsers
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeviceListUpdater_CleanUp(t *testing.T) {
|
||||||
|
processCtx := process.NewProcessContext()
|
||||||
|
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
|
||||||
|
// Bob is not joined to any of our rooms
|
||||||
|
rsAPI := &mockKeyserverRoomserverAPI{leftUsers: []string{bob.ID}}
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, clearDB := mustCreateKeyserverDB(t, dbType)
|
||||||
|
defer clearDB()
|
||||||
|
|
||||||
|
// This should not get deleted
|
||||||
|
if err := db.MarkDeviceListStale(processCtx.Context(), alice.ID, true); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// this one should get deleted
|
||||||
|
if err := db.MarkDeviceListStale(processCtx.Context(), bob.ID, true); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
updater := NewDeviceListUpdater(processCtx, db, nil,
|
||||||
|
nil, nil,
|
||||||
|
0, rsAPI, "test")
|
||||||
|
if err := updater.CleanUp(); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that we still have Alice in our stale list
|
||||||
|
staleUsers, err := db.StaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// There should only be Alice
|
||||||
|
wantCount := 1
|
||||||
|
if count := len(staleUsers); count != wantCount {
|
||||||
|
t.Fatalf("expected there to be %d stale device lists, got %d", wantCount, count)
|
||||||
|
}
|
||||||
|
|
||||||
|
if staleUsers[0] != alice.ID {
|
||||||
|
t.Fatalf("unexpected stale device list user: %s, want %s", staleUsers[0], alice.ID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -18,6 +18,8 @@ import (
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
rsapi "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
|
||||||
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
fedsenderapi "github.com/matrix-org/dendrite/federationapi/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/api"
|
"github.com/matrix-org/dendrite/keyserver/api"
|
||||||
"github.com/matrix-org/dendrite/keyserver/consumers"
|
"github.com/matrix-org/dendrite/keyserver/consumers"
|
||||||
|
@ -40,6 +42,7 @@ func AddInternalRoutes(router *mux.Router, intAPI api.KeyInternalAPI, enableMetr
|
||||||
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
// can call functions directly on the returned API or via an HTTP interface using AddInternalRoutes.
|
||||||
func NewInternalAPI(
|
func NewInternalAPI(
|
||||||
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
|
base *base.BaseDendrite, cfg *config.KeyServer, fedClient fedsenderapi.KeyserverFederationAPI,
|
||||||
|
rsAPI rsapi.KeyserverRoomserverAPI,
|
||||||
) api.KeyInternalAPI {
|
) api.KeyInternalAPI {
|
||||||
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
js, _ := base.NATS.Prepare(base.ProcessContext, &cfg.Matrix.JetStream)
|
||||||
|
|
||||||
|
@ -47,6 +50,7 @@ func NewInternalAPI(
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to connect to key server database")
|
logrus.WithError(err).Panicf("failed to connect to key server database")
|
||||||
}
|
}
|
||||||
|
|
||||||
keyChangeProducer := &producers.KeyChange{
|
keyChangeProducer := &producers.KeyChange{
|
||||||
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
|
Topic: string(cfg.Matrix.JetStream.Prefixed(jetstream.OutputKeyChangeEvent)),
|
||||||
JetStream: js,
|
JetStream: js,
|
||||||
|
@ -58,8 +62,14 @@ func NewInternalAPI(
|
||||||
FedClient: fedClient,
|
FedClient: fedClient,
|
||||||
Producer: keyChangeProducer,
|
Producer: keyChangeProducer,
|
||||||
}
|
}
|
||||||
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
updater := internal.NewDeviceListUpdater(base.ProcessContext, db, ap, keyChangeProducer, fedClient, 8, rsAPI, cfg.Matrix.ServerName) // 8 workers TODO: configurable
|
||||||
ap.Updater = updater
|
ap.Updater = updater
|
||||||
|
|
||||||
|
// Remove users which we don't share a room with anymore
|
||||||
|
if err := updater.CleanUp(); err != nil {
|
||||||
|
logrus.WithError(err).Error("failed to cleanup stale device lists")
|
||||||
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := updater.Start(); err != nil {
|
if err := updater.Start(); err != nil {
|
||||||
logrus.WithError(err).Panicf("failed to start device list updater")
|
logrus.WithError(err).Panicf("failed to start device list updater")
|
||||||
|
|
29
keyserver/keyserver_test.go
Normal file
29
keyserver/keyserver_test.go
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
package keyserver
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
roomserver "github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockKeyserverRoomserverAPI struct {
|
||||||
|
leftUsers []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockKeyserverRoomserverAPI) QueryLeftUsers(ctx context.Context, req *roomserver.QueryLeftUsersRequest, res *roomserver.QueryLeftUsersResponse) error {
|
||||||
|
res.LeftUsers = m.leftUsers
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Merely tests that we can create an internal keyserver API
|
||||||
|
func Test_NewInternalAPI(t *testing.T) {
|
||||||
|
rsAPI := &mockKeyserverRoomserverAPI{}
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, closeBase := testrig.CreateBaseDendrite(t, dbType)
|
||||||
|
defer closeBase()
|
||||||
|
_ = NewInternalAPI(base, &base.Cfg.KeyServer, nil, rsAPI)
|
||||||
|
})
|
||||||
|
}
|
|
@ -85,4 +85,9 @@ type Database interface {
|
||||||
|
|
||||||
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
|
StoreCrossSigningKeysForUser(ctx context.Context, userID string, keyMap types.CrossSigningKeyMap) error
|
||||||
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
StoreCrossSigningSigsForTarget(ctx context.Context, originUserID string, originKeyID gomatrixserverlib.KeyID, targetUserID string, targetKeyID gomatrixserverlib.KeyID, signature gomatrixserverlib.Base64Bytes) error
|
||||||
|
|
||||||
|
DeleteStaleDeviceLists(
|
||||||
|
ctx context.Context,
|
||||||
|
userIDs []string,
|
||||||
|
) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,6 +19,10 @@ import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -48,10 +52,14 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||||
const selectStaleDeviceListsSQL = "" +
|
const selectStaleDeviceListsSQL = "" +
|
||||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||||
|
|
||||||
|
const deleteStaleDevicesSQL = "" +
|
||||||
|
"DELETE FROM keyserver_stale_device_lists WHERE user_id = ANY($1)"
|
||||||
|
|
||||||
type staleDeviceListsStatements struct {
|
type staleDeviceListsStatements struct {
|
||||||
upsertStaleDeviceListStmt *sql.Stmt
|
upsertStaleDeviceListStmt *sql.Stmt
|
||||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||||
selectStaleDeviceListsStmt *sql.Stmt
|
selectStaleDeviceListsStmt *sql.Stmt
|
||||||
|
deleteStaleDeviceListsStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||||
|
@ -60,16 +68,12 @@ func NewPostgresStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, erro
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||||
}
|
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||||
return nil, err
|
{&s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL},
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
@ -105,6 +109,15 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteStaleDeviceLists removes users from stale device lists
|
||||||
|
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||||
|
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||||
|
) error {
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.deleteStaleDeviceListsStmt)
|
||||||
|
_, err := stmt.ExecContext(ctx, pq.Array(userIDs))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
|
@ -249,3 +249,13 @@ func (d *Database) StoreCrossSigningSigsForTarget(
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteStaleDeviceLists deletes stale device list entries for users we don't share a room with anymore.
|
||||||
|
func (d *Database) DeleteStaleDeviceLists(
|
||||||
|
ctx context.Context,
|
||||||
|
userIDs []string,
|
||||||
|
) error {
|
||||||
|
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||||
|
return d.StaleDeviceListsTable.DeleteStaleDeviceLists(ctx, txn, userIDs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -17,8 +17,11 @@ package sqlite3
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
@ -48,11 +51,15 @@ const selectStaleDeviceListsWithDomainsSQL = "" +
|
||||||
const selectStaleDeviceListsSQL = "" +
|
const selectStaleDeviceListsSQL = "" +
|
||||||
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
"SELECT user_id FROM keyserver_stale_device_lists WHERE is_stale = $1 ORDER BY ts_added_secs DESC"
|
||||||
|
|
||||||
|
const deleteStaleDevicesSQL = "" +
|
||||||
|
"DELETE FROM keyserver_stale_device_lists WHERE user_id IN ($1)"
|
||||||
|
|
||||||
type staleDeviceListsStatements struct {
|
type staleDeviceListsStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
upsertStaleDeviceListStmt *sql.Stmt
|
upsertStaleDeviceListStmt *sql.Stmt
|
||||||
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
selectStaleDeviceListsWithDomainsStmt *sql.Stmt
|
||||||
selectStaleDeviceListsStmt *sql.Stmt
|
selectStaleDeviceListsStmt *sql.Stmt
|
||||||
|
// deleteStaleDeviceListsStmt *sql.Stmt // Prepared at runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error) {
|
||||||
|
@ -63,16 +70,12 @@ func NewSqliteStaleDeviceListsTable(db *sql.DB) (tables.StaleDeviceLists, error)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if s.upsertStaleDeviceListStmt, err = db.Prepare(upsertStaleDeviceListSQL); err != nil {
|
return s, sqlutil.StatementList{
|
||||||
return nil, err
|
{&s.upsertStaleDeviceListStmt, upsertStaleDeviceListSQL},
|
||||||
}
|
{&s.selectStaleDeviceListsStmt, selectStaleDeviceListsSQL},
|
||||||
if s.selectStaleDeviceListsStmt, err = db.Prepare(selectStaleDeviceListsSQL); err != nil {
|
{&s.selectStaleDeviceListsWithDomainsStmt, selectStaleDeviceListsWithDomainsSQL},
|
||||||
return nil, err
|
// { &s.deleteStaleDeviceListsStmt, deleteStaleDevicesSQL}, // Prepared at runtime
|
||||||
}
|
}.Prepare(db)
|
||||||
if s.selectStaleDeviceListsWithDomainsStmt, err = db.Prepare(selectStaleDeviceListsWithDomainsSQL); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return s, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
func (s *staleDeviceListsStatements) InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error {
|
||||||
|
@ -108,6 +111,27 @@ func (s *staleDeviceListsStatements) SelectUserIDsWithStaleDeviceLists(ctx conte
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DeleteStaleDeviceLists removes users from stale device lists
|
||||||
|
func (s *staleDeviceListsStatements) DeleteStaleDeviceLists(
|
||||||
|
ctx context.Context, txn *sql.Tx, userIDs []string,
|
||||||
|
) error {
|
||||||
|
qry := strings.Replace(deleteStaleDevicesSQL, "($1)", sqlutil.QueryVariadic(len(userIDs)), 1)
|
||||||
|
stmt, err := s.db.Prepare(qry)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "DeleteStaleDeviceLists: stmt.Close failed")
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
|
||||||
|
params := make([]any, len(userIDs))
|
||||||
|
for i := range userIDs {
|
||||||
|
params[i] = userIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = stmt.ExecContext(ctx, params...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
func rowsToUserIDs(ctx context.Context, rows *sql.Rows) (result []string, err error) {
|
||||||
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
defer internal.CloseAndLogIfError(ctx, rows, "closing rowsToUserIDs failed")
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|
|
@ -56,6 +56,7 @@ type KeyChanges interface {
|
||||||
type StaleDeviceLists interface {
|
type StaleDeviceLists interface {
|
||||||
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
InsertStaleDeviceList(ctx context.Context, userID string, isStale bool) error
|
||||||
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
SelectUserIDsWithStaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error)
|
||||||
|
DeleteStaleDeviceLists(ctx context.Context, txn *sql.Tx, userIDs []string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
type CrossSigningKeys interface {
|
type CrossSigningKeys interface {
|
||||||
|
|
94
keyserver/storage/tables/stale_device_lists_test.go
Normal file
94
keyserver/storage/tables/stale_device_lists_test.go
Normal file
|
@ -0,0 +1,94 @@
|
||||||
|
package tables_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/keyserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateTable(t *testing.T, dbType test.DBType) (tab tables.StaleDeviceLists, close func()) {
|
||||||
|
connStr, close := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
db, err := sqlutil.Open(&config.DatabaseOptions{
|
||||||
|
ConnectionString: config.DataSource(connStr),
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to open database: %s", err)
|
||||||
|
}
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
tab, err = postgres.NewPostgresStaleDeviceListsTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
tab, err = sqlite3.NewSqliteStaleDeviceListsTable(db)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create new table: %s", err)
|
||||||
|
}
|
||||||
|
return tab, close
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestStaleDeviceLists(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := "@charlie:localhost"
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
tab, closeDB := mustCreateTable(t, dbType)
|
||||||
|
defer closeDB()
|
||||||
|
|
||||||
|
if err := tab.InsertStaleDeviceList(ctx, alice.ID, true); err != nil {
|
||||||
|
t.Fatalf("failed to insert stale device: %s", err)
|
||||||
|
}
|
||||||
|
if err := tab.InsertStaleDeviceList(ctx, bob.ID, true); err != nil {
|
||||||
|
t.Fatalf("failed to insert stale device: %s", err)
|
||||||
|
}
|
||||||
|
if err := tab.InsertStaleDeviceList(ctx, charlie, true); err != nil {
|
||||||
|
t.Fatalf("failed to insert stale device: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query one server
|
||||||
|
wantStaleUsers := []string{alice.ID, bob.ID}
|
||||||
|
gotStaleUsers, err := tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to query stale device lists: %s", err)
|
||||||
|
}
|
||||||
|
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||||
|
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query all servers
|
||||||
|
wantStaleUsers = []string{alice.ID, bob.ID, charlie}
|
||||||
|
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to query stale device lists: %s", err)
|
||||||
|
}
|
||||||
|
if !test.UnsortedStringSliceEqual(wantStaleUsers, gotStaleUsers) {
|
||||||
|
t.Fatalf("expected stale users %v, got %v", wantStaleUsers, gotStaleUsers)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete stale devices
|
||||||
|
deleteUsers := []string{alice.ID, bob.ID}
|
||||||
|
if err = tab.DeleteStaleDeviceLists(ctx, nil, deleteUsers); err != nil {
|
||||||
|
t.Fatalf("failed to delete stale device lists: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify we don't get anything back after deleting
|
||||||
|
gotStaleUsers, err = tab.SelectUserIDsWithStaleDeviceLists(ctx, []gomatrixserverlib.ServerName{"test"})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to query stale device lists: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if gotCount := len(gotStaleUsers); gotCount > 0 {
|
||||||
|
t.Fatalf("expected no stale users, got %d", gotCount)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
|
@ -17,6 +17,7 @@ type RoomserverInternalAPI interface {
|
||||||
ClientRoomserverAPI
|
ClientRoomserverAPI
|
||||||
UserRoomserverAPI
|
UserRoomserverAPI
|
||||||
FederationRoomserverAPI
|
FederationRoomserverAPI
|
||||||
|
KeyserverRoomserverAPI
|
||||||
|
|
||||||
// needed to avoid chicken and egg scenario when setting up the
|
// needed to avoid chicken and egg scenario when setting up the
|
||||||
// interdependencies between the roomserver and other input APIs
|
// interdependencies between the roomserver and other input APIs
|
||||||
|
@ -199,3 +200,7 @@ type FederationRoomserverAPI interface {
|
||||||
// Query a given amount (or less) of events prior to a given set of events.
|
// Query a given amount (or less) of events prior to a given set of events.
|
||||||
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
PerformBackfill(ctx context.Context, req *PerformBackfillRequest, res *PerformBackfillResponse) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type KeyserverRoomserverAPI interface {
|
||||||
|
QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error
|
||||||
|
}
|
||||||
|
|
|
@ -19,6 +19,12 @@ type RoomserverInternalAPITrace struct {
|
||||||
Impl RoomserverInternalAPI
|
Impl RoomserverInternalAPI
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *RoomserverInternalAPITrace) QueryLeftUsers(ctx context.Context, req *QueryLeftUsersRequest, res *QueryLeftUsersResponse) error {
|
||||||
|
err := t.Impl.QueryLeftUsers(ctx, req, res)
|
||||||
|
util.GetLogger(ctx).WithError(err).Infof("QueryLeftUsers req=%+v res=%+v", js(req), js(res))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
|
func (t *RoomserverInternalAPITrace) SetFederationAPI(fsAPI fsAPI.RoomserverFederationAPI, keyRing *gomatrixserverlib.KeyRing) {
|
||||||
t.Impl.SetFederationAPI(fsAPI, keyRing)
|
t.Impl.SetFederationAPI(fsAPI, keyRing)
|
||||||
}
|
}
|
||||||
|
|
|
@ -447,3 +447,15 @@ type QueryMembershipAtEventResponse struct {
|
||||||
// do not have known state will return an empty array here.
|
// do not have known state will return an empty array here.
|
||||||
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
|
Memberships map[string][]*gomatrixserverlib.HeaderedEvent `json:"memberships"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// QueryLeftUsersRequest is a request to calculate users that we (the server) don't share a
|
||||||
|
// a room with anymore. This is used to cleanup stale device list entries, where we would
|
||||||
|
// otherwise keep on trying to get device lists.
|
||||||
|
type QueryLeftUsersRequest struct {
|
||||||
|
StaleDeviceListUsers []string `json:"user_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryLeftUsersResponse is the response to QueryLeftUsersRequest.
|
||||||
|
type QueryLeftUsersResponse struct {
|
||||||
|
LeftUsers []string `json:"user_ids"`
|
||||||
|
}
|
||||||
|
|
|
@ -805,6 +805,12 @@ func (r *Queryer) QueryBulkStateContent(ctx context.Context, req *api.QueryBulkS
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *Queryer) QueryLeftUsers(ctx context.Context, req *api.QueryLeftUsersRequest, res *api.QueryLeftUsersResponse) error {
|
||||||
|
var err error
|
||||||
|
res.LeftUsers, err = r.DB.GetLeftUsers(ctx, req.StaleDeviceListUsers)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
func (r *Queryer) QuerySharedUsers(ctx context.Context, req *api.QuerySharedUsersRequest, res *api.QuerySharedUsersResponse) error {
|
||||||
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
roomIDs, err := r.DB.GetRoomsByMembership(ctx, req.UserID, "join")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -63,6 +63,7 @@ const (
|
||||||
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
|
RoomserverQueryAuthChainPath = "/roomserver/queryAuthChain"
|
||||||
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
|
RoomserverQueryRestrictedJoinAllowed = "/roomserver/queryRestrictedJoinAllowed"
|
||||||
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
|
RoomserverQueryMembershipAtEventPath = "/roomserver/queryMembershipAtEvent"
|
||||||
|
RoomserverQueryLeftMembersPath = "/roomserver/queryLeftMembers"
|
||||||
)
|
)
|
||||||
|
|
||||||
type httpRoomserverInternalAPI struct {
|
type httpRoomserverInternalAPI struct {
|
||||||
|
@ -553,3 +554,10 @@ func (h *httpRoomserverInternalAPI) QueryMembershipAtEvent(ctx context.Context,
|
||||||
h.httpClient, ctx, request, response,
|
h.httpClient, ctx, request, response,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (h *httpRoomserverInternalAPI) QueryLeftUsers(ctx context.Context, request *api.QueryLeftUsersRequest, response *api.QueryLeftUsersResponse) error {
|
||||||
|
return httputil.CallInternalRPCAPI(
|
||||||
|
"RoomserverQueryLeftMembers", h.roomserverURL+RoomserverQueryLeftMembersPath,
|
||||||
|
h.httpClient, ctx, request, response,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
|
@ -203,4 +203,9 @@ func AddRoutes(r api.RoomserverInternalAPI, internalAPIMux *mux.Router, enableMe
|
||||||
RoomserverQueryMembershipAtEventPath,
|
RoomserverQueryMembershipAtEventPath,
|
||||||
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent),
|
httputil.MakeInternalRPCAPI("RoomserverQueryMembershipAtEventPath", enableMetrics, r.QueryMembershipAtEvent),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
internalAPIMux.Handle(
|
||||||
|
RoomserverQueryLeftMembersPath,
|
||||||
|
httputil.MakeInternalRPCAPI("RoomserverQueryLeftMembersPath", enableMetrics, r.QueryLeftUsers),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,20 +2,27 @@ package roomserver_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/mux"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/httputil"
|
||||||
"github.com/matrix-org/dendrite/roomserver"
|
"github.com/matrix-org/dendrite/roomserver"
|
||||||
"github.com/matrix-org/dendrite/roomserver/api"
|
"github.com/matrix-org/dendrite/roomserver/api"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/inthttp"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage"
|
"github.com/matrix-org/dendrite/roomserver/storage"
|
||||||
"github.com/matrix-org/dendrite/setup/base"
|
"github.com/matrix-org/dendrite/setup/base"
|
||||||
"github.com/matrix-org/dendrite/test"
|
"github.com/matrix-org/dendrite/test"
|
||||||
"github.com/matrix-org/dendrite/test/testrig"
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
|
func mustCreateDatabase(t *testing.T, dbType test.DBType) (*base.BaseDendrite, storage.Database, func()) {
|
||||||
|
t.Helper()
|
||||||
base, close := testrig.CreateBaseDendrite(t, dbType)
|
base, close := testrig.CreateBaseDendrite(t, dbType)
|
||||||
db, err := storage.Open(base, &base.Cfg.KeyServer.Database, base.Caches)
|
db, err := storage.Open(base, &base.Cfg.RoomServer.Database, base.Caches)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create Database: %v", err)
|
t.Fatalf("failed to create Database: %v", err)
|
||||||
}
|
}
|
||||||
|
@ -67,3 +74,69 @@ func Test_SharedUsers(t *testing.T) {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_QueryLeftUsers(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
room := test.NewRoom(t, alice, test.RoomPreset(test.PresetTrustedPrivateChat))
|
||||||
|
|
||||||
|
// Invite and join Bob
|
||||||
|
room.CreateAndInsert(t, alice, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "invite",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
room.CreateAndInsert(t, bob, gomatrixserverlib.MRoomMember, map[string]interface{}{
|
||||||
|
"membership": "join",
|
||||||
|
}, test.WithStateKey(bob.ID))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
base, _, close := mustCreateDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
rsAPI := roomserver.NewInternalAPI(base)
|
||||||
|
// SetFederationAPI starts the room event input consumer
|
||||||
|
rsAPI.SetFederationAPI(nil, nil)
|
||||||
|
// Create the room
|
||||||
|
if err := api.SendEvents(ctx, rsAPI, api.KindNew, room.Events(), "test", "test", "test", nil, false); err != nil {
|
||||||
|
t.Fatalf("failed to send events: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query the left users, there should only be "@idontexist:test",
|
||||||
|
// as Alice and Bob are still joined.
|
||||||
|
res := &api.QueryLeftUsersResponse{}
|
||||||
|
leftUserID := "@idontexist:test"
|
||||||
|
getLeftUsersList := []string{alice.ID, bob.ID, leftUserID}
|
||||||
|
|
||||||
|
testCase := func(rsAPI api.RoomserverInternalAPI) {
|
||||||
|
if err := rsAPI.QueryLeftUsers(ctx, &api.QueryLeftUsersRequest{StaleDeviceListUsers: getLeftUsersList}, res); err != nil {
|
||||||
|
t.Fatalf("unable to query left users: %v", err)
|
||||||
|
}
|
||||||
|
wantCount := 1
|
||||||
|
if count := len(res.LeftUsers); count > wantCount {
|
||||||
|
t.Fatalf("unexpected left users count: want %d, got %d", wantCount, count)
|
||||||
|
}
|
||||||
|
if res.LeftUsers[0] != leftUserID {
|
||||||
|
t.Fatalf("unexpected left users : want %s, got %s", leftUserID, res.LeftUsers[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("HTTP API", func(t *testing.T) {
|
||||||
|
router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter()
|
||||||
|
roomserver.AddInternalRoutes(router, rsAPI, false)
|
||||||
|
apiURL, cancel := test.ListenAndServe(t, router, false)
|
||||||
|
defer cancel()
|
||||||
|
httpAPI, err := inthttp.NewRoomserverClient(apiURL, &http.Client{Timeout: time.Second * 5}, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create HTTP client")
|
||||||
|
}
|
||||||
|
testCase(httpAPI)
|
||||||
|
})
|
||||||
|
t.Run("Monolith", func(t *testing.T) {
|
||||||
|
testCase(rsAPI)
|
||||||
|
// also test tracing
|
||||||
|
traceAPI := &api.RoomserverInternalAPITrace{Impl: rsAPI}
|
||||||
|
testCase(traceAPI)
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
|
@ -172,5 +172,6 @@ type Database interface {
|
||||||
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
ForgetRoom(ctx context.Context, userID, roomID string, forget bool) error
|
||||||
|
|
||||||
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
|
GetHistoryVisibilityState(ctx context.Context, roomInfo *types.RoomInfo, eventID string, domain string) ([]*gomatrixserverlib.Event, error)
|
||||||
|
GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error)
|
||||||
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
UpgradeRoom(ctx context.Context, oldRoomID, newRoomID, eventSender string) error
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,12 +21,13 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/lib/pq"
|
"github.com/lib/pq"
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres/deltas"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const membershipSchema = `
|
const membershipSchema = `
|
||||||
|
@ -157,6 +158,12 @@ const selectServerInRoomSQL = "" +
|
||||||
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
" JOIN roomserver_event_state_keys ON roomserver_membership.target_nid = roomserver_event_state_keys.event_state_key_nid" +
|
||||||
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
" WHERE membership_nid = $1 AND room_nid = $2 AND event_state_key LIKE '%:' || $3 LIMIT 1"
|
||||||
|
|
||||||
|
const selectJoinedUsersSQL = `
|
||||||
|
SELECT DISTINCT target_nid
|
||||||
|
FROM roomserver_membership m
|
||||||
|
WHERE membership_nid > $1 AND target_nid = ANY($2)
|
||||||
|
`
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
selectMembershipForUpdateStmt *sql.Stmt
|
selectMembershipForUpdateStmt *sql.Stmt
|
||||||
|
@ -174,6 +181,7 @@ type membershipStatements struct {
|
||||||
selectLocalServerInRoomStmt *sql.Stmt
|
selectLocalServerInRoomStmt *sql.Stmt
|
||||||
selectServerInRoomStmt *sql.Stmt
|
selectServerInRoomStmt *sql.Stmt
|
||||||
deleteMembershipStmt *sql.Stmt
|
deleteMembershipStmt *sql.Stmt
|
||||||
|
selectJoinedUsersStmt *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateMembershipTable(db *sql.DB) error {
|
func CreateMembershipTable(db *sql.DB) error {
|
||||||
|
@ -209,9 +217,33 @@ func PrepareMembershipTable(db *sql.DB) (tables.Membership, error) {
|
||||||
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
|
{&s.selectLocalServerInRoomStmt, selectLocalServerInRoomSQL},
|
||||||
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
|
{&s.selectServerInRoomStmt, selectServerInRoomSQL},
|
||||||
{&s.deleteMembershipStmt, deleteMembershipSQL},
|
{&s.deleteMembershipStmt, deleteMembershipSQL},
|
||||||
|
{&s.selectJoinedUsersStmt, selectJoinedUsersSQL},
|
||||||
}.Prepare(db)
|
}.Prepare(db)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsers(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
targetUserNIDs []types.EventStateKeyNID,
|
||||||
|
) ([]types.EventStateKeyNID, error) {
|
||||||
|
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
|
||||||
|
|
||||||
|
stmt := sqlutil.TxStmt(txn, s.selectJoinedUsersStmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, tables.MembershipStateLeaveOrBan, pq.Array(targetUserNIDs))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
|
||||||
|
var targetNID types.EventStateKeyNID
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&targetNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, targetNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
func (s *membershipStatements) InsertMembership(
|
func (s *membershipStatements) InsertMembership(
|
||||||
ctx context.Context, txn *sql.Tx,
|
ctx context.Context, txn *sql.Tx,
|
||||||
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
roomNID types.RoomNID, targetUserNID types.EventStateKeyNID,
|
||||||
|
|
|
@ -1365,6 +1365,43 @@ func (d *Database) JoinedUsersSetInRooms(ctx context.Context, roomIDs, userIDs [
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLeftUsers calculates users we (the server) don't share a room with anymore.
|
||||||
|
func (d *Database) GetLeftUsers(ctx context.Context, userIDs []string) ([]string, error) {
|
||||||
|
// Get the userNID for all users with a stale device list
|
||||||
|
stateKeyNIDMap, err := d.EventStateKeyNIDs(ctx, userIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
userNIDs := make([]types.EventStateKeyNID, 0, len(stateKeyNIDMap))
|
||||||
|
userNIDtoUserID := make(map[types.EventStateKeyNID]string, len(stateKeyNIDMap))
|
||||||
|
// Create a map from userNID -> userID
|
||||||
|
for userID, nid := range stateKeyNIDMap {
|
||||||
|
userNIDs = append(userNIDs, nid)
|
||||||
|
userNIDtoUserID[nid] = userID
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get all users whose membership is still join, knock or invite.
|
||||||
|
stillJoinedUsersNIDs, err := d.MembershipTable.SelectJoinedUsers(ctx, nil, userNIDs)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove joined users from the "user with stale devices" list, which contains left AND joined users
|
||||||
|
for _, joinedUser := range stillJoinedUsersNIDs {
|
||||||
|
delete(userNIDtoUserID, joinedUser)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The users still in our userNIDtoUserID map are the users we don't share a room with anymore,
|
||||||
|
// and the return value we are looking for.
|
||||||
|
leftUsers := make([]string, 0, len(userNIDtoUserID))
|
||||||
|
for _, userID := range userNIDtoUserID {
|
||||||
|
leftUsers = append(leftUsers, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return leftUsers, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
// GetLocalServerInRoom returns true if we think we're in a given room or false otherwise.
|
||||||
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
func (d *Database) GetLocalServerInRoom(ctx context.Context, roomNID types.RoomNID) (bool, error) {
|
||||||
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
|
return d.MembershipTable.SelectLocalServerInRoom(ctx, nil, roomNID)
|
||||||
|
|
96
roomserver/storage/shared/storage_test.go
Normal file
96
roomserver/storage/shared/storage_test.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package shared_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/postgres"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/shared"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3"
|
||||||
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
|
"github.com/matrix-org/dendrite/setup/config"
|
||||||
|
"github.com/matrix-org/dendrite/test"
|
||||||
|
"github.com/matrix-org/dendrite/test/testrig"
|
||||||
|
)
|
||||||
|
|
||||||
|
func mustCreateRoomserverDatabase(t *testing.T, dbType test.DBType) (*shared.Database, func()) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
connStr, clearDB := test.PrepareDBConnectionString(t, dbType)
|
||||||
|
base, _, _ := testrig.Base(nil)
|
||||||
|
dbOpts := &config.DatabaseOptions{ConnectionString: config.DataSource(connStr)}
|
||||||
|
|
||||||
|
db, err := sqlutil.Open(dbOpts, sqlutil.NewExclusiveWriter())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var membershipTable tables.Membership
|
||||||
|
var stateKeyTable tables.EventStateKeys
|
||||||
|
switch dbType {
|
||||||
|
case test.DBTypePostgres:
|
||||||
|
err = postgres.CreateEventStateKeysTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = postgres.CreateMembershipTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
membershipTable, err = postgres.PrepareMembershipTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
stateKeyTable, err = postgres.PrepareEventStateKeysTable(db)
|
||||||
|
case test.DBTypeSQLite:
|
||||||
|
err = sqlite3.CreateEventStateKeysTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = sqlite3.CreateMembershipTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
membershipTable, err = sqlite3.PrepareMembershipTable(db)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
stateKeyTable, err = sqlite3.PrepareEventStateKeysTable(db)
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
return &shared.Database{
|
||||||
|
DB: db,
|
||||||
|
EventStateKeysTable: stateKeyTable,
|
||||||
|
MembershipTable: membershipTable,
|
||||||
|
Writer: sqlutil.NewExclusiveWriter(),
|
||||||
|
}, func() {
|
||||||
|
err := base.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
clearDB()
|
||||||
|
err = db.Close()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_GetLeftUsers(t *testing.T) {
|
||||||
|
alice := test.NewUser(t)
|
||||||
|
bob := test.NewUser(t)
|
||||||
|
charlie := test.NewUser(t)
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
|
db, close := mustCreateRoomserverDatabase(t, dbType)
|
||||||
|
defer close()
|
||||||
|
|
||||||
|
// Create dummy entries
|
||||||
|
for _, user := range []*test.User{alice, bob, charlie} {
|
||||||
|
nid, err := db.EventStateKeysTable.InsertEventStateKeyNID(ctx, nil, user.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = db.MembershipTable.InsertMembership(ctx, nil, 1, nid, true)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// We must update the membership with a non-zero event NID or it will get filtered out in later queries
|
||||||
|
membershipNID := tables.MembershipStateLeaveOrBan
|
||||||
|
if user == alice {
|
||||||
|
membershipNID = tables.MembershipStateJoin
|
||||||
|
}
|
||||||
|
_, err = db.MembershipTable.UpdateMembership(ctx, nil, 1, nid, nid, membershipNID, 1, false)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now try to get the left users, this should be Bob and Charlie, since they have a "leave" membership
|
||||||
|
expectedUserIDs := []string{bob.ID, charlie.ID}
|
||||||
|
leftUsers, err := db.GetLeftUsers(context.Background(), []string{alice.ID, bob.ID, charlie.ID})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.ElementsMatch(t, expectedUserIDs, leftUsers)
|
||||||
|
})
|
||||||
|
}
|
|
@ -21,12 +21,13 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/matrix-org/gomatrixserverlib"
|
||||||
|
|
||||||
"github.com/matrix-org/dendrite/internal"
|
"github.com/matrix-org/dendrite/internal"
|
||||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
|
"github.com/matrix-org/dendrite/roomserver/storage/sqlite3/deltas"
|
||||||
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
"github.com/matrix-org/dendrite/roomserver/storage/tables"
|
||||||
"github.com/matrix-org/dendrite/roomserver/types"
|
"github.com/matrix-org/dendrite/roomserver/types"
|
||||||
"github.com/matrix-org/gomatrixserverlib"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const membershipSchema = `
|
const membershipSchema = `
|
||||||
|
@ -133,6 +134,12 @@ const selectServerInRoomSQL = "" +
|
||||||
const deleteMembershipSQL = "" +
|
const deleteMembershipSQL = "" +
|
||||||
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
"DELETE FROM roomserver_membership WHERE room_nid = $1 AND target_nid = $2"
|
||||||
|
|
||||||
|
const selectJoinedUsersSQL = `
|
||||||
|
SELECT DISTINCT target_nid
|
||||||
|
FROM roomserver_membership m
|
||||||
|
WHERE membership_nid > $1 AND target_nid IN ($2)
|
||||||
|
`
|
||||||
|
|
||||||
type membershipStatements struct {
|
type membershipStatements struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
insertMembershipStmt *sql.Stmt
|
insertMembershipStmt *sql.Stmt
|
||||||
|
@ -149,6 +156,7 @@ type membershipStatements struct {
|
||||||
selectLocalServerInRoomStmt *sql.Stmt
|
selectLocalServerInRoomStmt *sql.Stmt
|
||||||
selectServerInRoomStmt *sql.Stmt
|
selectServerInRoomStmt *sql.Stmt
|
||||||
deleteMembershipStmt *sql.Stmt
|
deleteMembershipStmt *sql.Stmt
|
||||||
|
// selectJoinedUsersStmt *sql.Stmt // Prepared at runtime
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateMembershipTable(db *sql.DB) error {
|
func CreateMembershipTable(db *sql.DB) error {
|
||||||
|
@ -412,3 +420,40 @@ func (s *membershipStatements) DeleteMembership(
|
||||||
)
|
)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *membershipStatements) SelectJoinedUsers(
|
||||||
|
ctx context.Context, txn *sql.Tx,
|
||||||
|
targetUserNIDs []types.EventStateKeyNID,
|
||||||
|
) ([]types.EventStateKeyNID, error) {
|
||||||
|
result := make([]types.EventStateKeyNID, 0, len(targetUserNIDs))
|
||||||
|
|
||||||
|
qry := strings.Replace(selectJoinedUsersSQL, "($2)", sqlutil.QueryVariadicOffset(len(targetUserNIDs), 1), 1)
|
||||||
|
|
||||||
|
stmt, err := s.db.Prepare(qry)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, stmt, "SelectJoinedUsers: stmt.Close failed")
|
||||||
|
|
||||||
|
params := make([]any, len(targetUserNIDs)+1)
|
||||||
|
params[0] = tables.MembershipStateLeaveOrBan
|
||||||
|
for i := range targetUserNIDs {
|
||||||
|
params[i+1] = targetUserNIDs[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
stmt = sqlutil.TxStmt(txn, stmt)
|
||||||
|
rows, err := stmt.QueryContext(ctx, params...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer internal.CloseAndLogIfError(ctx, rows, "SelectJoinedUsers: rows.close() failed")
|
||||||
|
var targetNID types.EventStateKeyNID
|
||||||
|
for rows.Next() {
|
||||||
|
if err = rows.Scan(&targetNID); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result = append(result, targetNID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, rows.Err()
|
||||||
|
}
|
||||||
|
|
|
@ -144,6 +144,7 @@ type Membership interface {
|
||||||
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
SelectLocalServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID) (bool, error)
|
||||||
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
SelectServerInRoom(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, serverName gomatrixserverlib.ServerName) (bool, error)
|
||||||
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
|
DeleteMembership(ctx context.Context, txn *sql.Tx, roomNID types.RoomNID, targetUserNID types.EventStateKeyNID) error
|
||||||
|
SelectJoinedUsers(ctx context.Context, txn *sql.Tx, targetUserNIDs []types.EventStateKeyNID) ([]types.EventStateKeyNID, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type Published interface {
|
type Published interface {
|
||||||
|
|
|
@ -129,5 +129,11 @@ func TestMembershipTable(t *testing.T) {
|
||||||
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
|
knownUsers, err := tab.SelectKnownUsers(ctx, nil, userNIDs[0], "localhost", 2)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, 1, len(knownUsers))
|
assert.Equal(t, 1, len(knownUsers))
|
||||||
|
|
||||||
|
// get users we share a room with, given their userNID
|
||||||
|
joinedUsers, err := tab.SelectJoinedUsers(ctx, nil, userNIDs)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
// Only userNIDs[0] is actually joined, so we only expect this userNID
|
||||||
|
assert.Equal(t, userNIDs[:1], joinedUsers)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue