mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 07:28:27 +00:00
Fix issue where device keys are removed if a device ID is reused (#2982)
Fixes https://github.com/matrix-org/dendrite/issues/2980
This commit is contained in:
parent
4594233f89
commit
7f114cc538
2 changed files with 55 additions and 16 deletions
|
@ -254,6 +254,17 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||||
return fmt.Errorf("server name %s is not local", serverName)
|
return fmt.Errorf("server name %s is not local", serverName)
|
||||||
}
|
}
|
||||||
|
// If a device ID was specified, check if it already exists and
|
||||||
|
// avoid sending an empty device list update which would remove
|
||||||
|
// existing device keys.
|
||||||
|
isExisting := false
|
||||||
|
if req.DeviceID != nil && *req.DeviceID != "" {
|
||||||
|
existingDev, err := a.DB.GetDeviceByID(ctx, req.Localpart, req.ServerName, *req.DeviceID)
|
||||||
|
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
isExisting = existingDev.ID == *req.DeviceID
|
||||||
|
}
|
||||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||||
"localpart": req.Localpart,
|
"localpart": req.Localpart,
|
||||||
"device_id": req.DeviceID,
|
"device_id": req.DeviceID,
|
||||||
|
@ -265,7 +276,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
||||||
}
|
}
|
||||||
res.DeviceCreated = true
|
res.DeviceCreated = true
|
||||||
res.Device = dev
|
res.Device = dev
|
||||||
if req.NoDeviceListUpdate {
|
if req.NoDeviceListUpdate || isExisting {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
// create empty device keys and upload them to trigger device list changes
|
// create empty device keys and upload them to trigger device list changes
|
||||||
|
|
|
@ -18,6 +18,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -44,13 +45,25 @@ type apiTestOpts struct {
|
||||||
serverName string
|
serverName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type dummyProducer struct{}
|
type dummyProducer struct {
|
||||||
|
callCount sync.Map
|
||||||
|
t *testing.T
|
||||||
|
}
|
||||||
|
|
||||||
func (d *dummyProducer) PublishMsg(*nats.Msg, ...nats.PubOpt) (*nats.PubAck, error) {
|
func (d *dummyProducer) PublishMsg(msg *nats.Msg, opts ...nats.PubOpt) (*nats.PubAck, error) {
|
||||||
|
count, loaded := d.callCount.LoadOrStore(msg.Subject, 1)
|
||||||
|
if loaded {
|
||||||
|
c, ok := count.(int)
|
||||||
|
if !ok {
|
||||||
|
d.t.Fatalf("unexpected type: %T with value %q", c, c)
|
||||||
|
}
|
||||||
|
d.callCount.Store(msg.Subject, c+1)
|
||||||
|
d.t.Logf("Incrementing call counter for %s", msg.Subject)
|
||||||
|
}
|
||||||
return &nats.PubAck{}, nil
|
return &nats.PubAck{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (api.UserInternalAPI, storage.UserDatabase, func()) {
|
func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType, publisher producers.JetStreamPublisher) (api.UserInternalAPI, storage.UserDatabase, func()) {
|
||||||
if opts.loginTokenLifetime == 0 {
|
if opts.loginTokenLifetime == 0 {
|
||||||
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
opts.loginTokenLifetime = api.DefaultLoginTokenLifetime * time.Millisecond
|
||||||
}
|
}
|
||||||
|
@ -82,8 +95,12 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts, dbType test.DBType) (ap
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
syncProducer := producers.NewSyncAPI(accountDB, &dummyProducer{}, "", "")
|
if publisher == nil {
|
||||||
keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: &dummyProducer{}}
|
publisher = &dummyProducer{t: t}
|
||||||
|
}
|
||||||
|
|
||||||
|
syncProducer := producers.NewSyncAPI(accountDB, publisher, "client_data", "notification_data")
|
||||||
|
keyChangeProducer := &producers.KeyChange{DB: keyDB, JetStream: publisher, Topic: "keychange"}
|
||||||
return &internal.UserInternalAPI{
|
return &internal.UserInternalAPI{
|
||||||
DB: accountDB,
|
DB: accountDB,
|
||||||
KeyDatabase: keyDB,
|
KeyDatabase: keyDB,
|
||||||
|
@ -150,7 +167,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(context.TODO(), "alice", serverName, "foobar", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -173,7 +190,7 @@ func TestQueryProfile(t *testing.T) {
|
||||||
func TestPasswordlessLoginFails(t *testing.T) {
|
func TestPasswordlessLoginFails(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "", "", api.AccountTypeAppService)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -199,7 +216,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
|
|
||||||
t.Run("tokenLoginFlow", func(t *testing.T) {
|
t.Run("tokenLoginFlow", func(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, accountDB, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
_, err := accountDB.CreateAccount(ctx, "auser", serverName, "apassword", "", api.AccountTypeUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -249,7 +266,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
|
|
||||||
t.Run("expiredTokenIsNotReturned", func(t *testing.T) {
|
t.Run("expiredTokenIsNotReturned", func(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType)
|
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
creq := api.PerformLoginTokenCreationRequest{
|
creq := api.PerformLoginTokenCreationRequest{
|
||||||
|
@ -274,7 +291,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
|
|
||||||
t.Run("deleteWorks", func(t *testing.T) {
|
t.Run("deleteWorks", func(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
creq := api.PerformLoginTokenCreationRequest{
|
creq := api.PerformLoginTokenCreationRequest{
|
||||||
|
@ -305,7 +322,7 @@ func TestLoginToken(t *testing.T) {
|
||||||
|
|
||||||
t.Run("deleteUnknownIsNoOp", func(t *testing.T) {
|
t.Run("deleteUnknownIsNoOp", func(t *testing.T) {
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
userAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"}
|
dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"}
|
||||||
var dresp api.PerformLoginTokenDeletionResponse
|
var dresp api.PerformLoginTokenDeletionResponse
|
||||||
|
@ -323,7 +340,7 @@ func TestQueryAccountByLocalpart(t *testing.T) {
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType)
|
intAPI, db, close := MustMakeInternalAPI(t, apiTestOpts{}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
createdAcc, err := db.CreateAccount(ctx, localpart, userServername, "", "", alice.AccountType)
|
||||||
|
@ -402,7 +419,7 @@ func TestAccountData(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
|
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
|
@ -518,7 +535,7 @@ func TestDevices(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
|
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, nil)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
for _, tc := range creationTests {
|
for _, tc := range creationTests {
|
||||||
|
@ -623,7 +640,8 @@ func TestDevices(t *testing.T) {
|
||||||
func TestDeviceIDReuse(t *testing.T) {
|
func TestDeviceIDReuse(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
|
||||||
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType)
|
publisher := &dummyProducer{t: t}
|
||||||
|
intAPI, _, close := MustMakeInternalAPI(t, apiTestOpts{serverName: "test"}, dbType, publisher)
|
||||||
defer close()
|
defer close()
|
||||||
|
|
||||||
res := api.PerformDeviceCreationResponse{}
|
res := api.PerformDeviceCreationResponse{}
|
||||||
|
@ -637,6 +655,9 @@ func TestDeviceIDReuse(t *testing.T) {
|
||||||
|
|
||||||
// Do the same request again, we expect a different sessionID
|
// Do the same request again, we expect a different sessionID
|
||||||
res2 := api.PerformDeviceCreationResponse{}
|
res2 := api.PerformDeviceCreationResponse{}
|
||||||
|
// Set NoDeviceListUpdate to false, to verify we don't send device list updates when
|
||||||
|
// reusing the same device ID
|
||||||
|
req.NoDeviceListUpdate = false
|
||||||
err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
|
err = intAPI.PerformDeviceCreation(ctx, &req, &res2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("expected no error, but got: %v", err)
|
t.Fatalf("expected no error, but got: %v", err)
|
||||||
|
@ -645,5 +666,12 @@ func TestDeviceIDReuse(t *testing.T) {
|
||||||
if res2.Device.SessionID == res.Device.SessionID {
|
if res2.Device.SessionID == res.Device.SessionID {
|
||||||
t.Fatalf("expected a different session ID, but they are the same")
|
t.Fatalf("expected a different session ID, but they are the same")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
publisher.callCount.Range(func(key, value any) bool {
|
||||||
|
if value != nil {
|
||||||
|
t.Fatalf("expected publisher to not get called, but got value %d for subject %s", value, key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue