mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 13:52:46 +00:00
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
This commit is contained in:
parent
e177e0ae73
commit
529df30b56
62 changed files with 1250 additions and 732 deletions
|
@ -68,9 +68,10 @@ const (
|
|||
// GetAccountByPassword returns the account associated with the given localpart and password.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) (*api.Account, error) {
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
|
||||
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -80,24 +81,27 @@ func (d *Database) GetAccountByPassword(
|
|||
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
return d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetProfileByLocalpart returns the profile associated with the given localpart.
|
||||
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
|
||||
func (d *Database) GetProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
|
||||
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// SetAvatarURL updates the avatar URL of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetAvatarURL(
|
||||
ctx context.Context, localpart string, avatarURL string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
|
||||
profile, changed, err = d.Profiles.SetAvatarURL(ctx, txn, localpart, serverName, avatarURL)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -106,10 +110,12 @@ func (d *Database) SetAvatarURL(
|
|||
// SetDisplayName updates the display name of the profile associated with the given
|
||||
// localpart. Returns an error if something went wrong with the SQL query
|
||||
func (d *Database) SetDisplayName(
|
||||
ctx context.Context, localpart string, displayName string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (profile *authtypes.Profile, changed bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
|
||||
profile, changed, err = d.Profiles.SetDisplayName(ctx, txn, localpart, serverName, displayName)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -117,14 +123,15 @@ func (d *Database) SetDisplayName(
|
|||
|
||||
// SetPassword sets the account password to the given hash.
|
||||
func (d *Database) SetPassword(
|
||||
ctx context.Context, localpart, plaintextPassword string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword string,
|
||||
) error {
|
||||
hash, err := d.hashPassword(plaintextPassword)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, hash)
|
||||
return d.Accounts.UpdatePassword(ctx, localpart, serverName, hash)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -132,21 +139,22 @@ func (d *Database) SetPassword(
|
|||
// for this account. If no password is supplied, the account will be a passwordless account. If the
|
||||
// account already exists, it will return nil, ErrUserExists.
|
||||
func (d *Database) CreateAccount(
|
||||
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (acc *api.Account, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
// For guest accounts, we create a new numeric local part
|
||||
if accountType == api.AccountTypeGuest {
|
||||
var numLocalpart int64
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
|
||||
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn, serverName)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("d.Accounts.SelectNewNumericLocalpart: %w", err)
|
||||
}
|
||||
localpart = strconv.FormatInt(numLocalpart, 10)
|
||||
plaintextPassword = ""
|
||||
appserviceID = ""
|
||||
}
|
||||
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
|
||||
acc, err = d.createAccount(ctx, txn, localpart, serverName, plaintextPassword, appserviceID, accountType)
|
||||
return err
|
||||
})
|
||||
return
|
||||
|
@ -155,7 +163,9 @@ func (d *Database) CreateAccount(
|
|||
// WARNING! This function assumes that the relevant mutexes have already
|
||||
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
|
||||
func (d *Database) createAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
plaintextPassword, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
var err error
|
||||
var account *api.Account
|
||||
|
@ -167,28 +177,28 @@ func (d *Database) createAccount(
|
|||
return nil, err
|
||||
}
|
||||
}
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
|
||||
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, serverName, hash, appserviceID, accountType); err != nil {
|
||||
return nil, sqlutil.ErrUserExists
|
||||
}
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
|
||||
return nil, err
|
||||
if err = d.Profiles.InsertProfile(ctx, txn, localpart, serverName); err != nil {
|
||||
return nil, fmt.Errorf("d.Profiles.InsertProfile: %w", err)
|
||||
}
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("json.Marshal: %w", err)
|
||||
}
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, err
|
||||
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", json.RawMessage(prbs)); err != nil {
|
||||
return nil, fmt.Errorf("d.AccountDatas.InsertAccountData: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
func (d *Database) QueryPushRules(
|
||||
ctx context.Context,
|
||||
localpart string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*pushrules.AccountRuleSets, error) {
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, "", "m.push_rules")
|
||||
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -196,13 +206,13 @@ func (d *Database) QueryPushRules(
|
|||
// If we didn't find any default push rules then we should just generate some
|
||||
// fresh ones.
|
||||
if len(data) == 0 {
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, d.ServerName)
|
||||
pushRuleSets := pushrules.DefaultAccountRuleSets(localpart, serverName)
|
||||
prbs, err := json.Marshal(pushRuleSets)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal default push rules: %w", err)
|
||||
}
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", prbs); dbErr != nil {
|
||||
if dbErr := d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, "", "m.push_rules", prbs); dbErr != nil {
|
||||
return fmt.Errorf("failed to save default push rules: %w", dbErr)
|
||||
}
|
||||
return nil
|
||||
|
@ -225,22 +235,23 @@ func (d *Database) QueryPushRules(
|
|||
// update the corresponding row with the new content
|
||||
// Returns a SQL error if there was an issue with the insertion/update
|
||||
func (d *Database) SaveAccountData(
|
||||
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
|
||||
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, serverName, roomID, dataType, content)
|
||||
})
|
||||
}
|
||||
|
||||
// GetAccountData returns account data related to a given localpart
|
||||
// If no account data could be found, returns an empty arrays
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
||||
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
|
||||
global map[string]json.RawMessage,
|
||||
rooms map[string]map[string]json.RawMessage,
|
||||
err error,
|
||||
) {
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart)
|
||||
return d.AccountDatas.SelectAccountData(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// GetAccountDataByType returns account data matching a given
|
||||
|
@ -248,18 +259,19 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string) (
|
|||
// If no account data could be found, returns nil
|
||||
// Returns an error if there was an issue with the retrieval
|
||||
func (d *Database) GetAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
return d.AccountDatas.SelectAccountDataByType(
|
||||
ctx, localpart, roomID, dataType,
|
||||
ctx, localpart, serverName, roomID, dataType,
|
||||
)
|
||||
}
|
||||
|
||||
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
|
||||
func (d *Database) GetNewNumericLocalpart(
|
||||
ctx context.Context,
|
||||
ctx context.Context, serverName gomatrixserverlib.ServerName,
|
||||
) (int64, error) {
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
|
||||
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
|
||||
}
|
||||
|
||||
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
|
||||
|
@ -276,10 +288,12 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
|
|||
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
|
||||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) SaveThreePIDAssociation(
|
||||
ctx context.Context, threepid, localpart, medium string,
|
||||
ctx context.Context, threepid string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
medium string,
|
||||
) (err error) {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
user, _, err := d.ThreePIDs.SelectLocalpartForThreePID(
|
||||
ctx, txn, threepid, medium,
|
||||
)
|
||||
if err != nil {
|
||||
|
@ -290,7 +304,7 @@ func (d *Database) SaveThreePIDAssociation(
|
|||
return Err3PIDInUse
|
||||
}
|
||||
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
|
||||
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -313,7 +327,7 @@ func (d *Database) RemoveThreePIDAssociation(
|
|||
// Returns an error if there was a problem talking to the database.
|
||||
func (d *Database) GetLocalpartForThreePID(
|
||||
ctx context.Context, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
|
||||
}
|
||||
|
||||
|
@ -322,16 +336,17 @@ func (d *Database) GetLocalpartForThreePID(
|
|||
// If no association is known for this user, returns an empty slice.
|
||||
// Returns an error if there was an issue talking to the database.
|
||||
func (d *Database) GetThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
|
||||
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
|
||||
}
|
||||
|
||||
// CheckAccountAvailability checks if the username/localpart is already present
|
||||
// in the database.
|
||||
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
|
||||
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
|
||||
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return true, nil
|
||||
}
|
||||
|
@ -341,12 +356,12 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
|
|||
// GetAccountByLocalpart returns the account associated with the given localpart.
|
||||
// This function assumes the request is authenticated or the account data is used only internally.
|
||||
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
|
||||
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
// try to get the account with lowercase localpart (majority)
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart))
|
||||
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart) // try with localpart as passed by the request
|
||||
acc, err = d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName) // try with localpart as passed by the request
|
||||
}
|
||||
return acc, err
|
||||
}
|
||||
|
@ -359,20 +374,24 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
|
|||
}
|
||||
|
||||
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
|
||||
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
|
||||
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart)
|
||||
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
|
||||
func (d *Database) CreateOpenIDToken(
|
||||
ctx context.Context,
|
||||
token, localpart string,
|
||||
token, userID string,
|
||||
) (int64, error) {
|
||||
localpart, domain, err := gomatrixserverlib.SplitID('@', userID)
|
||||
if err != nil {
|
||||
return 0, nil
|
||||
}
|
||||
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
|
||||
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, domain, expiresAtMS)
|
||||
})
|
||||
return expiresAtMS, err
|
||||
}
|
||||
|
@ -539,16 +558,19 @@ func (d *Database) GetDeviceByAccessToken(
|
|||
// GetDeviceByID returns the device matching the given ID.
|
||||
// Returns sql.ErrNoRows if no matching device was found.
|
||||
func (d *Database) GetDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
|
||||
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
|
||||
}
|
||||
|
||||
// GetDevicesByLocalpart returns the devices matching the given localpart.
|
||||
func (d *Database) GetDevicesByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Device, error) {
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
|
||||
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
|
||||
}
|
||||
|
||||
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
|
@ -562,18 +584,18 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
|
|||
// If no device ID is given one is generated.
|
||||
// Returns the device on success.
|
||||
func (d *Database) CreateDevice(
|
||||
ctx context.Context, localpart string, deviceID *string, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (dev *api.Device, returnErr error) {
|
||||
if deviceID != nil {
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
// Revoke existing tokens for this device
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
|
||||
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
} else {
|
||||
|
@ -588,7 +610,7 @@ func (d *Database) CreateDevice(
|
|||
|
||||
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
var err error
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
|
||||
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
|
||||
return err
|
||||
})
|
||||
if returnErr == nil {
|
||||
|
@ -614,10 +636,12 @@ func generateDeviceID() (string, error) {
|
|||
// UpdateDevice updates the given device with the display name.
|
||||
// Returns SQL error if there are problems and nil on success.
|
||||
func (d *Database) UpdateDevice(
|
||||
ctx context.Context, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
|
||||
return d.Devices.UpdateDeviceName(ctx, txn, localpart, serverName, deviceID, displayName)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -626,10 +650,12 @@ func (d *Database) UpdateDevice(
|
|||
// If the devices don't exist, it will not return an error
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveDevices(
|
||||
ctx context.Context, localpart string, devices []string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevices(ctx, txn, localpart, serverName, devices); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -640,14 +666,16 @@ func (d *Database) RemoveDevices(
|
|||
// database matching the given user ID localpart.
|
||||
// If something went wrong during the deletion, it will return the SQL error.
|
||||
func (d *Database) RemoveAllDevices(
|
||||
ctx context.Context, localpart, exceptDeviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) (devices []api.Device, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
|
||||
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
|
||||
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, serverName, exceptDeviceID); err != sql.ErrNoRows {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
@ -656,9 +684,9 @@ func (d *Database) RemoveAllDevices(
|
|||
}
|
||||
|
||||
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr, userAgent)
|
||||
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -706,38 +734,38 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
|
|||
return d.LoginTokens.SelectLoginToken(ctx, token)
|
||||
}
|
||||
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
return d.Notifications.Insert(ctx, txn, localpart, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
|
||||
})
|
||||
}
|
||||
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos uint64) (affected bool, err error) {
|
||||
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, roomID, pos)
|
||||
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
|
||||
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, roomID, pos, b)
|
||||
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
|
||||
return err
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, fromID, limit, filter)
|
||||
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, filter)
|
||||
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
|
||||
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
|
||||
}
|
||||
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, roomID)
|
||||
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
|
||||
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
|
||||
}
|
||||
|
||||
func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
||||
|
@ -747,7 +775,8 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (d *Database) UpsertPusher(
|
||||
ctx context.Context, p api.Pusher, localpart string,
|
||||
ctx context.Context, p api.Pusher,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
data, err := json.Marshal(p.Data)
|
||||
if err != nil {
|
||||
|
@ -766,25 +795,26 @@ func (d *Database) UpsertPusher(
|
|||
p.ProfileTag,
|
||||
p.Language,
|
||||
string(data),
|
||||
localpart)
|
||||
localpart,
|
||||
serverName)
|
||||
})
|
||||
}
|
||||
|
||||
// GetPushers returns the pushers matching the given localpart.
|
||||
func (d *Database) GetPushers(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart)
|
||||
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
|
||||
}
|
||||
|
||||
// RemovePusher deletes one pusher
|
||||
// Invoked when `append` is true and `kind` is null in
|
||||
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
|
||||
func (d *Database) RemovePusher(
|
||||
ctx context.Context, appid, pushkey, localpart string,
|
||||
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart)
|
||||
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue