From 7f114cc5387f04d748270d48f92708f137df38a7 Mon Sep 17 00:00:00 2001 From: Till <2353100+S7evinK@users.noreply.github.com> Date: Mon, 20 Feb 2023 15:26:09 +0100 Subject: [PATCH] Fix issue where device keys are removed if a device ID is reused (#2982) Fixes https://github.com/matrix-org/dendrite/issues/2980 --- userapi/internal/user_api.go | 13 +++++++- userapi/userapi_test.go | 58 ++++++++++++++++++++++++++---------- 2 files changed, 55 insertions(+), 16 deletions(-) diff --git a/userapi/internal/user_api.go b/userapi/internal/user_api.go index 1cbd9719..8977697b 100644 --- a/userapi/internal/user_api.go +++ b/userapi/internal/user_api.go @@ -254,6 +254,17 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe if !a.Config.Matrix.IsLocalServerName(serverName) { return fmt.Errorf("server name %s is not local", serverName) } + // If a device ID was specified, check if it already exists and + // avoid sending an empty device list update which would remove + // existing device keys. + isExisting := false + if req.DeviceID != nil && *req.DeviceID != "" { + existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + isExisting = existingDev.ID == *req.DeviceID + } util.GetLogger(ctx).WithFields(logrus.Fields{ "localpart": req.Localpart, "device_id": req.DeviceID, @@ -265,7 +276,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe } res.DeviceCreated = true res.Device = dev - if req.NoDeviceListUpdate { + if req.NoDeviceListUpdate || isExisting { return nil } // create empty device keys and upload them to trigger device list changes diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 08b1336b..01e491cb 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -18,6 +18,7 @@ import ( "context" "fmt" "reflect" + "sync" "testing" "time" @@ -44,13 +45,25 @@ type apiTestOpts struct { serverName string } -type dummyProducer struct{} +type dummyProducer struct { + callCount sync.Map + t *testing.T +} -func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) { +func (d *dummyProducer) PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) { + count, loaded := d.callCount.LoadOrStore(msg.Subject, 1) + if loaded { + c, ok := count.(int) + if !ok { + d.t.Fatalf("unexpected type: %T with value %q", c, c) + } + d.callCount.Store(msg.Subject, c+1) + d.t.Logf("Incrementing call counter for %s", msg.Subject) + } return &nats.PubAck{}, nil } -func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) { +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, publisher producers.JetStreamPublisher) (api.UserInternalAPI, storage.UserDatabase, func()) { if opts.loginTokenLifetime == 0 { opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond } @@ -82,8 +95,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap }, } - syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "") - keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}} + if publisher == nil { + publisher = &dummyProducer{t: t} + } + + syncProducer := producers.NewSyncAPI(accountDB, publisher, "client_data", "notification_data") + keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: publisher, Topic: "keychange"} return &internal.UserInternalAPI{ DB: accountDB, KeyDatabase: keyDB, @@ -150,7 +167,7 @@ func TestQueryProfile(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser) if err != nil { @@ -173,7 +190,7 @@ func TestQueryProfile(t *testing.T) { func TestPasswordlessLoginFails(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService) if err != nil { @@ -199,7 +216,7 @@ func TestLoginToken(t *testing.T) { t.Run("tokenLoginFlow", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() _, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser) if err != nil { @@ -249,7 +266,7 @@ func TestLoginToken(t *testing.T) { t.Run("expiredTokenIsNotReturned", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -274,7 +291,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteWorks", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() creq := api.PerformLoginTokenCreationRequest{ @@ -305,7 +322,7 @@ func TestLoginToken(t *testing.T) { t.Run("deleteUnknownIsNoOp", func(t *testing.T) { test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} var dresp api.PerformLoginTokenDeletionResponse @@ -323,7 +340,7 @@ func TestQueryAccountByLocalpart(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType) + intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil) defer close() createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType) @@ -402,7 +419,7 @@ func TestAccountData(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) defer close() for _, tc := range testCases { @@ -518,7 +535,7 @@ func TestDevices(t *testing.T) { } test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil) defer close() for _, tc := range creationTests { @@ -623,7 +640,8 @@ func TestDevices(t *testing.T) { func TestDeviceIDReuse(t *testing.T) { ctx := context.Background() test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) { - intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType) + publisher := &dummyProducer{t: t} + intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, publisher) defer close() res := api.PerformDeviceCreationResponse{} @@ -637,6 +655,9 @@ func TestDeviceIDReuse(t *testing.T) { // Do the same request again, we expect a different sessionID res2 := api.PerformDeviceCreationResponse{} + // Set NoDeviceListUpdate to false, to verify we don't send device list updates when + // reusing the same device ID + req.NoDeviceListUpdate = false err = intAPI.PerformDeviceCreation(ctx, &req, &res2) if err != nil { t.Fatalf("expected no error, but got: %v", err) @@ -645,5 +666,12 @@ func TestDeviceIDReuse(t *testing.T) { if res2.Device.SessionID == res.Device.SessionID { t.Fatalf("expected a different session ID, but they are the same") } + + publisher.callCount.Range(func(key, value any) bool { + if value != nil { + t.Fatalf("expected publisher to not get called, but got value %d for subject %s", value, key) + } + return true + }) }) }