Merge both user API databases into one (#2186)

* Merge user API databases into one

* Remove DeviceDatabase from config

* Fix tests

* Try that again

* Clean up keyserver device keys when the devices no longer exist in the user API

* Tweak ordering

* Fix UserExists flag, device check

* Allow including empty entries so we can clean them up

* Remove logging
This commit is contained in:
Neil Alexander 2022-02-18 11:31:05 +00:00 committed by GitHub
parent 0a7dea4450
commit 153bfbbea5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
76 changed files with 727 additions and 899 deletions

View file

@ -31,13 +31,11 @@ import (
keyapi "github.com/matrix-org/dendrite/keyserver/api"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/accounts"
"github.com/matrix-org/dendrite/userapi/storage/devices"
"github.com/matrix-org/dendrite/userapi/storage"
)
type UserInternalAPI struct {
AccountDB accounts.Database
DeviceDB devices.Database
DB storage.Database
ServerName gomatrixserverlib.ServerName
// AppServices is the list of all registered AS
AppServices []config.ApplicationService
@ -55,11 +53,11 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
if req.DataType == "" {
return fmt.Errorf("data type must not be empty")
}
return a.AccountDB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
return a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData)
}
func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.PerformAccountCreationRequest, res *api.PerformAccountCreationResponse) error {
acc, err := a.AccountDB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
if err != nil {
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
switch req.OnConflict {
@ -89,7 +87,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
return nil
}
if err = a.AccountDB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
if err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
return err
}
@ -99,7 +97,7 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
}
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
if err := a.AccountDB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
return err
}
res.PasswordUpdated = true
@ -112,7 +110,7 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
"device_id": req.DeviceID,
"display_name": req.DeviceDisplayName,
}).Info("PerformDeviceCreation")
dev, err := a.DeviceDB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
if err != nil {
return err
}
@ -137,12 +135,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
deletedDeviceIDs := req.DeviceIDs
if len(req.DeviceIDs) == 0 {
var devices []api.Device
devices, err = a.DeviceDB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
for _, d := range devices {
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
}
} else {
err = a.DeviceDB.RemoveDevices(ctx, local, req.DeviceIDs)
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
}
if err != nil {
return err
@ -196,7 +194,7 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
if err != nil {
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
}
if err := a.DeviceDB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr); err != nil {
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
}
return nil
@ -208,7 +206,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
return err
}
dev, err := a.DeviceDB.GetDeviceByID(ctx, localpart, req.DeviceID)
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
if err == sql.ErrNoRows {
res.DeviceExists = false
return nil
@ -223,7 +221,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
return nil
}
err = a.DeviceDB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
if err != nil {
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
return err
@ -261,7 +259,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
if domain != a.ServerName {
return fmt.Errorf("cannot query profile of remote users: got %s want %s", domain, a.ServerName)
}
prof, err := a.AccountDB.GetProfileByLocalpart(ctx, local)
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -275,7 +273,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
}
func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.QuerySearchProfilesRequest, res *api.QuerySearchProfilesResponse) error {
profiles, err := a.AccountDB.SearchProfiles(ctx, req.SearchString, req.Limit)
profiles, err := a.DB.SearchProfiles(ctx, req.SearchString, req.Limit)
if err != nil {
return err
}
@ -284,7 +282,7 @@ func (a *UserInternalAPI) QuerySearchProfiles(ctx context.Context, req *api.Quer
}
func (a *UserInternalAPI) QueryDeviceInfos(ctx context.Context, req *api.QueryDeviceInfosRequest, res *api.QueryDeviceInfosResponse) error {
devices, err := a.DeviceDB.GetDevicesByID(ctx, req.DeviceIDs)
devices, err := a.DB.GetDevicesByID(ctx, req.DeviceIDs)
if err != nil {
return err
}
@ -312,10 +310,11 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
if domain != a.ServerName {
return fmt.Errorf("cannot query devices of remote users: got %s want %s", domain, a.ServerName)
}
devs, err := a.DeviceDB.GetDevicesByLocalpart(ctx, local)
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
if err != nil {
return err
}
res.UserExists = true
res.Devices = devs
return nil
}
@ -330,7 +329,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
if req.DataType != "" {
var data json.RawMessage
data, err = a.AccountDB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
if err != nil {
return err
}
@ -348,7 +347,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
}
return nil
}
global, rooms, err := a.AccountDB.GetAccountData(ctx, local)
global, rooms, err := a.DB.GetAccountData(ctx, local)
if err != nil {
return err
}
@ -367,7 +366,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
return nil
}
device, err := a.DeviceDB.GetDeviceByAccessToken(ctx, req.AccessToken)
device, err := a.DB.GetDeviceByAccessToken(ctx, req.AccessToken)
if err != nil {
if err == sql.ErrNoRows {
return nil
@ -378,7 +377,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
if err != nil {
return err
}
acc, err := a.AccountDB.GetAccountByLocalpart(ctx, localPart)
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
if err != nil {
return err
}
@ -419,7 +418,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
if localpart != "" { // AS is masquerading as another user
// Verify that the user is registered
account, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart)
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
// Verify that the account exists and either appServiceID matches or
// it belongs to the appservice user namespaces
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
@ -437,7 +436,7 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
// PerformAccountDeactivation deactivates the user's account, removing all ability for the user to login again.
func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *api.PerformAccountDeactivationRequest, res *api.PerformAccountDeactivationResponse) error {
err := a.AccountDB.DeactivateAccount(ctx, req.Localpart)
err := a.DB.DeactivateAccount(ctx, req.Localpart)
res.AccountDeactivated = err == nil
return err
}
@ -446,7 +445,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *api.PerformOpenIDTokenCreationRequest, res *api.PerformOpenIDTokenCreationResponse) error {
token := util.RandomString(24)
exp, err := a.AccountDB.CreateOpenIDToken(ctx, token, req.UserID)
exp, err := a.DB.CreateOpenIDToken(ctx, token, req.UserID)
res.Token = api.OpenIDToken{
Token: token,
@ -459,7 +458,7 @@ func (a *UserInternalAPI) PerformOpenIDTokenCreation(ctx context.Context, req *a
// QueryOpenIDToken validates that the OpenID token was issued for the user, the replying party uses this for validation
func (a *UserInternalAPI) QueryOpenIDToken(ctx context.Context, req *api.QueryOpenIDTokenRequest, res *api.QueryOpenIDTokenResponse) error {
openIDTokenAttrs, err := a.AccountDB.GetOpenIDTokenAttributes(ctx, req.Token)
openIDTokenAttrs, err := a.DB.GetOpenIDTokenAttributes(ctx, req.Token)
if err != nil {
return err
}
@ -481,7 +480,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
return nil
}
exists, err := a.AccountDB.DeleteKeyBackup(ctx, req.UserID, req.Version)
exists, err := a.DB.DeleteKeyBackup(ctx, req.UserID, req.Version)
if err != nil {
res.Error = fmt.Sprintf("failed to delete backup: %s", err)
}
@ -494,7 +493,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Create metadata
if req.Version == "" {
version, err := a.AccountDB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
version, err := a.DB.CreateKeyBackup(ctx, req.UserID, req.Algorithm, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to create backup: %s", err)
}
@ -507,7 +506,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
}
// Update metadata
if len(req.Keys.Rooms) == 0 {
err := a.AccountDB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
err := a.DB.UpdateKeyBackupAuthData(ctx, req.UserID, req.Version, req.AuthData)
if err != nil {
res.Error = fmt.Sprintf("failed to update backup: %s", err)
}
@ -528,7 +527,7 @@ func (a *UserInternalAPI) PerformKeyBackup(ctx context.Context, req *api.Perform
func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.PerformKeyBackupRequest, res *api.PerformKeyBackupResponse) {
// you can only upload keys for the CURRENT version
version, _, _, _, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, "")
version, _, _, _, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, "")
if err != nil {
res.Error = fmt.Sprintf("failed to query version: %s", err)
return
@ -556,7 +555,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
})
}
}
count, etag, err := a.AccountDB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
count, etag, err := a.DB.UpsertBackupKeys(ctx, version, req.UserID, uploads)
if err != nil {
res.Error = fmt.Sprintf("failed to upsert keys: %s", err)
return
@ -566,7 +565,7 @@ func (a *UserInternalAPI) uploadBackupKeys(ctx context.Context, req *api.Perform
}
func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyBackupRequest, res *api.QueryKeyBackupResponse) {
version, algorithm, authData, etag, deleted, err := a.AccountDB.GetKeyBackup(ctx, req.UserID, req.Version)
version, algorithm, authData, etag, deleted, err := a.DB.GetKeyBackup(ctx, req.UserID, req.Version)
res.Version = version
if err != nil {
if err == sql.ErrNoRows {
@ -582,14 +581,14 @@ func (a *UserInternalAPI) QueryKeyBackup(ctx context.Context, req *api.QueryKeyB
res.Exists = !deleted
if !req.ReturnKeys {
res.Count, err = a.AccountDB.CountBackupKeys(ctx, version, req.UserID)
res.Count, err = a.DB.CountBackupKeys(ctx, version, req.UserID)
if err != nil {
res.Error = fmt.Sprintf("failed to count keys: %s", err)
}
return
}
result, err := a.AccountDB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
result, err := a.DB.GetBackupKeys(ctx, version, req.UserID, req.KeysForRoomID, req.KeysForSessionID)
if err != nil {
res.Error = fmt.Sprintf("failed to query keys: %s", err)
return

View file

@ -34,7 +34,7 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
if domain != a.ServerName {
return fmt.Errorf("cannot create a login token for a remote user: got %s want %s", domain, a.ServerName)
}
tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data)
tokenMeta, err := a.DB.CreateLoginToken(ctx, &req.Data)
if err != nil {
return err
}
@ -45,13 +45,13 @@ func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *ap
// PerformLoginTokenDeletion ensures the token doesn't exist.
func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error {
util.GetLogger(ctx).Info("PerformLoginTokenDeletion")
return a.DeviceDB.RemoveLoginToken(ctx, req.Token)
return a.DB.RemoveLoginToken(ctx, req.Token)
}
// QueryLoginToken returns the data associated with a login token. If
// the token is not valid, success is returned, but res.Data == nil.
func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error {
tokenData, err := a.DeviceDB.GetLoginTokenDataByToken(ctx, req.Token)
tokenData, err := a.DB.GetLoginTokenDataByToken(ctx, req.Token)
if err != nil {
res.Data = nil
if err == sql.ErrNoRows {
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
if domain != a.ServerName {
return fmt.Errorf("cannot return a login token for a remote user: got %s want %s", domain, a.ServerName)
}
if _, err := a.AccountDB.GetAccountByLocalpart(ctx, localpart); err != nil {
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
res.Data = nil
if err == sql.ErrNoRows {
return nil