mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-07-31 13:22:46 +00:00
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
This commit is contained in:
parent
e177e0ae73
commit
529df30b56
62 changed files with 1250 additions and 732 deletions
|
@ -68,7 +68,7 @@ func (a *UserInternalAPI) InputAccountData(ctx context.Context, req *api.InputAc
|
|||
if req.DataType == "" {
|
||||
return fmt.Errorf("data type must not be empty")
|
||||
}
|
||||
if err := a.DB.SaveAccountData(ctx, local, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
if err := a.DB.SaveAccountData(ctx, local, domain, req.RoomID, req.DataType, req.AccountData); err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("a.DB.SaveAccountData failed")
|
||||
return fmt.Errorf("failed to save account data: %w", err)
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
|||
return nil
|
||||
}
|
||||
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
deleted, err := a.DB.DeleteNotificationsUpTo(ctx, localpart, domain, req.RoomID, uint64(gomatrixserverlib.AsTimestamp(time.Now())))
|
||||
if err != nil {
|
||||
logrus.WithError(err).Errorf("UserInternalAPI.setFullyRead: DeleteNotificationsUpTo failed")
|
||||
return err
|
||||
|
@ -124,7 +124,7 @@ func (a *UserInternalAPI) setFullyRead(ctx context.Context, req *api.InputAccoun
|
|||
return nil
|
||||
}
|
||||
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, a.DB); err != nil {
|
||||
if err = userapiUtil.NotifyUserCountsAsync(ctx, a.PgClient, localpart, domain, a.DB); err != nil {
|
||||
logrus.WithError(err).Error("UserInternalAPI.setFullyRead: NotifyUserCounts failed")
|
||||
return err
|
||||
}
|
||||
|
@ -175,8 +175,10 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
// XXXX: Use the server name here
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, req.Password, req.AppServiceID, req.AccountType)
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
acc, err := a.DB.CreateAccount(ctx, req.Localpart, serverName, req.Password, req.AppServiceID, req.AccountType)
|
||||
if err != nil {
|
||||
if errors.Is(err, sqlutil.ErrUserExists) { // This account already exists
|
||||
switch req.OnConflict {
|
||||
|
@ -215,8 +217,8 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
return nil
|
||||
}
|
||||
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, req.Localpart); err != nil {
|
||||
return err
|
||||
if _, _, err = a.DB.SetDisplayName(ctx, req.Localpart, serverName, req.Localpart); err != nil {
|
||||
return fmt.Errorf("a.DB.SetDisplayName: %w", err)
|
||||
}
|
||||
|
||||
postRegisterJoinRooms(a.Cfg, acc, a.RSAPI)
|
||||
|
@ -227,11 +229,14 @@ func (a *UserInternalAPI) PerformAccountCreation(ctx context.Context, req *api.P
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPasswordUpdate(ctx context.Context, req *api.PerformPasswordUpdateRequest, res *api.PerformPasswordUpdateResponse) error {
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.Password); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(req.ServerName) {
|
||||
return fmt.Errorf("server name %s is not local", req.ServerName)
|
||||
}
|
||||
if err := a.DB.SetPassword(ctx, req.Localpart, req.ServerName, req.Password); err != nil {
|
||||
return err
|
||||
}
|
||||
if req.LogoutDevices {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, ""); err != nil {
|
||||
if _, err := a.DB.RemoveAllDevices(context.Background(), req.Localpart, req.ServerName, ""); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -244,14 +249,15 @@ func (a *UserInternalAPI) PerformDeviceCreation(ctx context.Context, req *api.Pe
|
|||
if serverName == "" {
|
||||
serverName = a.Config.Matrix.ServerName
|
||||
}
|
||||
_ = serverName
|
||||
// XXXX: Use the server name here
|
||||
if !a.Config.Matrix.IsLocalServerName(serverName) {
|
||||
return fmt.Errorf("server name %s is not local", serverName)
|
||||
}
|
||||
util.GetLogger(ctx).WithFields(logrus.Fields{
|
||||
"localpart": req.Localpart,
|
||||
"device_id": req.DeviceID,
|
||||
"display_name": req.DeviceDisplayName,
|
||||
}).Info("PerformDeviceCreation")
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
dev, err := a.DB.CreateDevice(ctx, req.Localpart, serverName, req.DeviceID, req.AccessToken, req.DeviceDisplayName, req.IPAddr, req.UserAgent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -276,12 +282,12 @@ func (a *UserInternalAPI) PerformDeviceDeletion(ctx context.Context, req *api.Pe
|
|||
deletedDeviceIDs := req.DeviceIDs
|
||||
if len(req.DeviceIDs) == 0 {
|
||||
var devices []api.Device
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, req.ExceptDeviceID)
|
||||
devices, err = a.DB.RemoveAllDevices(ctx, local, domain, req.ExceptDeviceID)
|
||||
for _, d := range devices {
|
||||
deletedDeviceIDs = append(deletedDeviceIDs, d.ID)
|
||||
}
|
||||
} else {
|
||||
err = a.DB.RemoveDevices(ctx, local, req.DeviceIDs)
|
||||
err = a.DB.RemoveDevices(ctx, local, domain, req.DeviceIDs)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -335,23 +341,29 @@ func (a *UserInternalAPI) PerformLastSeenUpdate(
|
|||
req *api.PerformLastSeenUpdateRequest,
|
||||
res *api.PerformLastSeenUpdateResponse,
|
||||
) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gomatrixserverlib.SplitID: %w", err)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
if err := a.DB.UpdateDeviceLastSeen(ctx, localpart, domain, req.DeviceID, req.RemoteAddr, req.UserAgent); err != nil {
|
||||
return fmt.Errorf("a.DeviceDB.UpdateDeviceLastSeen: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.PerformDeviceUpdateRequest, res *api.PerformDeviceUpdateResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.RequestingUserID)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("gomatrixserverlib.SplitID failed")
|
||||
return err
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, req.DeviceID)
|
||||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("server name %s is not local", domain)
|
||||
}
|
||||
dev, err := a.DB.GetDeviceByID(ctx, localpart, domain, req.DeviceID)
|
||||
if err == sql.ErrNoRows {
|
||||
res.DeviceExists = false
|
||||
return nil
|
||||
|
@ -366,7 +378,7 @@ func (a *UserInternalAPI) PerformDeviceUpdate(ctx context.Context, req *api.Perf
|
|||
return nil
|
||||
}
|
||||
|
||||
err = a.DB.UpdateDevice(ctx, localpart, req.DeviceID, req.DisplayName)
|
||||
err = a.DB.UpdateDevice(ctx, localpart, domain, req.DeviceID, req.DisplayName)
|
||||
if err != nil {
|
||||
util.GetLogger(ctx).WithError(err).Error("deviceDB.UpdateDevice failed")
|
||||
return err
|
||||
|
@ -406,7 +418,7 @@ func (a *UserInternalAPI) QueryProfile(ctx context.Context, req *api.QueryProfil
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query profile of remote users (server name %s)", domain)
|
||||
}
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local)
|
||||
prof, err := a.DB.GetProfileByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
@ -457,7 +469,7 @@ func (a *UserInternalAPI) QueryDevices(ctx context.Context, req *api.QueryDevice
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot query devices of remote users (server name %s)", domain)
|
||||
}
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local)
|
||||
devs, err := a.DB.GetDevicesByLocalpart(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -476,7 +488,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
if req.DataType != "" {
|
||||
var data json.RawMessage
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, req.RoomID, req.DataType)
|
||||
data, err = a.DB.GetAccountDataByType(ctx, local, domain, req.RoomID, req.DataType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -494,7 +506,7 @@ func (a *UserInternalAPI) QueryAccountData(ctx context.Context, req *api.QueryAc
|
|||
}
|
||||
return nil
|
||||
}
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local)
|
||||
global, rooms, err := a.DB.GetAccountData(ctx, local, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -527,7 +539,7 @@ func (a *UserInternalAPI) QueryAccessToken(ctx context.Context, req *api.QueryAc
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return nil
|
||||
}
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart)
|
||||
acc, err := a.DB.GetAccountByLocalpart(ctx, localPart, domain)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -561,14 +573,14 @@ func (a *UserInternalAPI) queryAppServiceToken(ctx context.Context, token, appSe
|
|||
AccountType: api.AccountTypeAppService,
|
||||
}
|
||||
|
||||
localpart, _, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
localpart, domain, err := userutil.ParseUsernameParam(appServiceUserID, a.Config.Matrix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if localpart != "" { // AS is masquerading as another user
|
||||
// Verify that the user is registered
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart)
|
||||
account, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain)
|
||||
// Verify that the account exists and either appServiceID matches or
|
||||
// it belongs to the appservice user namespaces
|
||||
if err == nil && (account.AppServiceID == appService.ID || appService.IsInterestedInUserID(appServiceUserID)) {
|
||||
|
@ -620,7 +632,7 @@ func (a *UserInternalAPI) PerformAccountDeactivation(ctx context.Context, req *a
|
|||
return err
|
||||
}
|
||||
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart)
|
||||
err := a.DB.DeactivateAccount(ctx, req.Localpart, serverName)
|
||||
res.AccountDeactivated = err == nil
|
||||
return err
|
||||
}
|
||||
|
@ -783,7 +795,7 @@ func (a *UserInternalAPI) QueryNotifications(ctx context.Context, req *api.Query
|
|||
if req.Only == "highlight" {
|
||||
filter = tables.HighlightNotifications
|
||||
}
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, fromID, req.Limit, filter)
|
||||
notifs, lastID, err := a.DB.GetNotifications(ctx, req.Localpart, req.ServerName, fromID, req.Limit, filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -811,23 +823,23 @@ func (a *UserInternalAPI) PerformPusherSet(ctx context.Context, req *api.Perform
|
|||
}
|
||||
}
|
||||
if req.Pusher.Kind == "" {
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart)
|
||||
return a.DB.RemovePusher(ctx, req.Pusher.AppID, req.Pusher.PushKey, req.Localpart, req.ServerName)
|
||||
}
|
||||
if req.Pusher.PushKeyTS == 0 {
|
||||
req.Pusher.PushKeyTS = int64(time.Now().Unix())
|
||||
}
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart)
|
||||
return a.DB.UpsertPusher(ctx, req.Pusher, req.Localpart, req.ServerName)
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.PerformPusherDeletionRequest, res *struct{}) error {
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart)
|
||||
pushers, err := a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for i := range pushers {
|
||||
logrus.Warnf("pusher session: %d, req session: %d", pushers[i].SessionID, req.SessionID)
|
||||
if pushers[i].SessionID != req.SessionID {
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart)
|
||||
err := a.DB.RemovePusher(ctx, pushers[i].AppID, pushers[i].PushKey, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -838,7 +850,7 @@ func (a *UserInternalAPI) PerformPusherDeletion(ctx context.Context, req *api.Pe
|
|||
|
||||
func (a *UserInternalAPI) QueryPushers(ctx context.Context, req *api.QueryPushersRequest, res *api.QueryPushersResponse) error {
|
||||
var err error
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart)
|
||||
res.Pushers, err = a.DB.GetPushers(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -864,11 +876,11 @@ func (a *UserInternalAPI) PerformPushRulesPut(
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPushRulesRequest, res *api.QueryPushRulesResponse) error {
|
||||
localpart, _, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', req.UserID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to split user ID %q for push rules", req.UserID)
|
||||
}
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart)
|
||||
pushRules, err := a.DB.QueryPushRules(ctx, localpart, domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query push rules: %w", err)
|
||||
}
|
||||
|
@ -877,14 +889,14 @@ func (a *UserInternalAPI) QueryPushRules(ctx context.Context, req *api.QueryPush
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) SetAvatarURL(ctx context.Context, req *api.PerformSetAvatarURLRequest, res *api.PerformSetAvatarURLResponse) error {
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.AvatarURL)
|
||||
profile, changed, err := a.DB.SetAvatarURL(ctx, req.Localpart, req.ServerName, req.AvatarURL)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx)
|
||||
func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, req *api.QueryNumericLocalpartRequest, res *api.QueryNumericLocalpartResponse) error {
|
||||
id, err := a.DB.GetNewNumericLocalpart(ctx, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -894,12 +906,12 @@ func (a *UserInternalAPI) QueryNumericLocalpart(ctx context.Context, res *api.Qu
|
|||
|
||||
func (a *UserInternalAPI) QueryAccountAvailability(ctx context.Context, req *api.QueryAccountAvailabilityRequest, res *api.QueryAccountAvailabilityResponse) error {
|
||||
var err error
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart)
|
||||
res.Available, err = a.DB.CheckAccountAvailability(ctx, req.Localpart, req.ServerName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.QueryAccountByPasswordRequest, res *api.QueryAccountByPasswordResponse) error {
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.PlaintextPassword)
|
||||
acc, err := a.DB.GetAccountByPassword(ctx, req.Localpart, req.ServerName, req.PlaintextPassword)
|
||||
switch err {
|
||||
case sql.ErrNoRows: // user does not exist
|
||||
return nil
|
||||
|
@ -915,23 +927,24 @@ func (a *UserInternalAPI) QueryAccountByPassword(ctx context.Context, req *api.Q
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) SetDisplayName(ctx context.Context, req *api.PerformUpdateDisplayNameRequest, res *api.PerformUpdateDisplayNameResponse) error {
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.DisplayName)
|
||||
profile, changed, err := a.DB.SetDisplayName(ctx, req.Localpart, req.ServerName, req.DisplayName)
|
||||
res.Profile = profile
|
||||
res.Changed = changed
|
||||
return err
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryLocalpartForThreePID(ctx context.Context, req *api.QueryLocalpartForThreePIDRequest, res *api.QueryLocalpartForThreePIDResponse) error {
|
||||
localpart, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
localpart, domain, err := a.DB.GetLocalpartForThreePID(ctx, req.ThreePID, req.Medium)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
res.Localpart = localpart
|
||||
res.ServerName = domain
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *UserInternalAPI) QueryThreePIDsForLocalpart(ctx context.Context, req *api.QueryThreePIDsForLocalpartRequest, res *api.QueryThreePIDsForLocalpartResponse) error {
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart)
|
||||
r, err := a.DB.GetThreePIDsForLocalpart(ctx, req.Localpart, req.ServerName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -944,7 +957,7 @@ func (a *UserInternalAPI) PerformForgetThreePID(ctx context.Context, req *api.Pe
|
|||
}
|
||||
|
||||
func (a *UserInternalAPI) PerformSaveThreePIDAssociation(ctx context.Context, req *api.PerformSaveThreePIDAssociationRequest, res *struct{}) error {
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.Medium)
|
||||
return a.DB.SaveThreePIDAssociation(ctx, req.ThreePID, req.Localpart, req.ServerName, req.Medium)
|
||||
}
|
||||
|
||||
const pushRulesAccountDataType = "m.push_rules"
|
||||
|
|
|
@ -66,7 +66,7 @@ func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLog
|
|||
if !a.Config.Matrix.IsLocalServerName(domain) {
|
||||
return fmt.Errorf("cannot return a login token for a remote user (server name %s)", domain)
|
||||
}
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart); err != nil {
|
||||
if _, err := a.DB.GetAccountByLocalpart(ctx, localpart, domain); err != nil {
|
||||
res.Data = nil
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue