refactor: update GMSL (#3058)

Sister PR to https://github.com/matrix-org/gomatrixserverlib/pull/364

Read this commit by commit to avoid going insane.
This commit is contained in:
kegsay 2023-04-19 15:50:33 +01:00 committed by GitHub
parent 9fa39263c0
commit 72285b2659
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
306 changed files with 2117 additions and 1934 deletions

View file

@ -28,6 +28,7 @@ import (
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/gomatrixserverlib/fclient"
"github.com/matrix-org/gomatrixserverlib/spec"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
@ -55,7 +56,7 @@ type Database struct {
Pushers tables.PusherTable
Stats tables.StatsTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
ServerName spec.ServerName
BcryptCost int
OpenIDTokenLifetimeMS int64
}
@ -80,7 +81,7 @@ const (
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
plaintextPassword string,
) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart, serverName)
@ -100,7 +101,7 @@ func (d *Database) GetAccountByPassword(
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
) (*authtypes.Profile, error) {
return d.Profiles.SelectProfileByLocalpart(ctx, localpart, serverName)
}
@ -109,7 +110,7 @@ func (d *Database) GetProfileByLocalpart(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
avatarURL string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -123,7 +124,7 @@ func (d *Database) SetAvatarURL(
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
displayName string,
) (profile *authtypes.Profile, changed bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -135,7 +136,7 @@ func (d *Database) SetDisplayName(
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
@ -151,7 +152,7 @@ func (d *Database) SetPassword(
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -176,7 +177,7 @@ func (d *Database) CreateAccount(
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var err error
@ -208,7 +209,7 @@ func (d *Database) createAccount(
func (d *Database) QueryPushRules(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
) (*pushrules.AccountRuleSets, error) {
data, err := d.AccountDatas.SelectAccountDataByType(ctx, localpart, serverName, "", "m.push_rules")
if err != nil {
@ -247,7 +248,7 @@ func (d *Database) QueryPushRules(
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
roomID, dataType string, content json.RawMessage,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -258,7 +259,7 @@ func (d *Database) SaveAccountData(
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (
func (d *Database) GetAccountData(ctx context.Context, localpart string, serverName spec.ServerName) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
@ -271,7 +272,7 @@ func (d *Database) GetAccountData(ctx context.Context, localpart string, serverN
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
roomID, dataType string,
) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType(
@ -281,7 +282,7 @@ func (d *Database) GetAccountDataByType(
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context, serverName gomatrixserverlib.ServerName,
ctx context.Context, serverName spec.ServerName,
) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil, serverName)
}
@ -301,7 +302,7 @@ var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid string,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -339,7 +340,7 @@ func (d *Database) RemoveThreePIDAssociation(
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
) (localpart string, serverName spec.ServerName, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
@ -349,7 +350,7 @@ func (d *Database) GetLocalpartForThreePID(
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart, serverName)
}
@ -357,7 +358,7 @@ func (d *Database) GetThreePIDsForLocalpart(
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (bool, error) {
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string, serverName spec.ServerName) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart, serverName)
if err == sql.ErrNoRows {
return true, nil
@ -368,7 +369,7 @@ func (d *Database) CheckAccountAvailability(ctx context.Context, localpart strin
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string, serverName spec.ServerName,
) (*api.Account, error) {
// try to get the account with lowercase localpart (majority)
acc, err := d.Accounts.SelectAccountByLocalpart(ctx, strings.ToLower(localpart), serverName)
@ -386,7 +387,7 @@ func (d *Database) SearchProfiles(ctx context.Context, searchString string, limi
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName) (err error) {
func (d *Database) DeactivateAccount(ctx context.Context, localpart string, serverName spec.ServerName) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart, serverName)
})
@ -571,7 +572,7 @@ func (d *Database) GetDeviceByAccessToken(
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, serverName, deviceID)
@ -580,7 +581,7 @@ func (d *Database) GetDeviceByID(
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, serverName, "")
}
@ -596,7 +597,7 @@ func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]ap
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
@ -675,7 +676,7 @@ func generateDeviceID() (string, error) {
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -689,7 +690,7 @@ func (d *Database) UpdateDevice(
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -705,7 +706,7 @@ func (d *Database) RemoveDevices(
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
@ -722,7 +723,7 @@ func (d *Database) RemoveAllDevices(
}
// UpdateDeviceLastSeen updates a last seen timestamp and the ip address.
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart string, serverName spec.ServerName, deviceID, ipAddr, userAgent string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, serverName, deviceID, ipAddr, userAgent)
})
@ -772,13 +773,13 @@ func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (
return d.LoginTokens.SelectLoginToken(ctx, token)
}
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
func (d *Database) InsertNotification(ctx context.Context, localpart string, serverName spec.ServerName, eventID string, pos uint64, tweaks map[string]interface{}, n *api.Notification) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Notifications.Insert(ctx, txn, localpart, serverName, eventID, pos, pushrules.BoolTweakOr(tweaks, pushrules.HighlightTweak, false), n)
})
}
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, err error) {
func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.DeleteUpTo(ctx, txn, localpart, serverName, roomID, pos)
return err
@ -786,7 +787,7 @@ func (d *Database) DeleteNotificationsUpTo(ctx context.Context, localpart string
return
}
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, serverName spec.ServerName, roomID string, pos uint64, b bool) (affected bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
affected, err = d.Notifications.UpdateRead(ctx, txn, localpart, serverName, roomID, pos, b)
return err
@ -794,15 +795,15 @@ func (d *Database) SetNotificationsRead(ctx context.Context, localpart string, s
return
}
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
func (d *Database) GetNotifications(ctx context.Context, localpart string, serverName spec.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
return d.Notifications.Select(ctx, nil, localpart, serverName, fromID, limit, filter)
}
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (int64, error) {
func (d *Database) GetNotificationCount(ctx context.Context, localpart string, serverName spec.ServerName, filter tables.NotificationFilter) (int64, error) {
return d.Notifications.SelectCount(ctx, nil, localpart, serverName, filter)
}
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, _ error) {
func (d *Database) GetRoomNotificationCounts(ctx context.Context, localpart string, serverName spec.ServerName, roomID string) (total int64, highlight int64, _ error) {
return d.Notifications.SelectRoomCounts(ctx, nil, localpart, serverName, roomID)
}
@ -814,7 +815,7 @@ func (d *Database) DeleteOldNotifications(ctx context.Context) error {
func (d *Database) UpsertPusher(
ctx context.Context, p api.Pusher,
localpart string, serverName gomatrixserverlib.ServerName,
localpart string, serverName spec.ServerName,
) error {
data, err := json.Marshal(p.Data)
if err != nil {
@ -840,7 +841,7 @@ func (d *Database) UpsertPusher(
// GetPushers returns the pushers matching the given localpart.
func (d *Database) GetPushers(
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, localpart string, serverName spec.ServerName,
) ([]api.Pusher, error) {
return d.Pushers.SelectPushers(ctx, nil, localpart, serverName)
}
@ -849,7 +850,7 @@ func (d *Database) GetPushers(
// Invoked when `append` is true and `kind` is null in
// https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-pushers-set
func (d *Database) RemovePusher(
ctx context.Context, appid, pushkey, localpart string, serverName gomatrixserverlib.ServerName,
ctx context.Context, appid, pushkey, localpart string, serverName spec.ServerName,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
err := d.Pushers.DeletePusher(ctx, txn, appid, pushkey, localpart, serverName)
@ -876,14 +877,14 @@ func (d *Database) UserStatistics(ctx context.Context) (*types.UserStatistics, *
return d.Stats.UserStatistics(ctx, nil)
}
func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName gomatrixserverlib.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error {
func (d *Database) UpsertDailyRoomsMessages(ctx context.Context, serverName spec.ServerName, stats types.MessageStats, activeRooms, activeE2EERooms int64) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Stats.UpsertDailyStats(ctx, txn, serverName, stats, activeRooms, activeE2EERooms)
})
}
func (d *Database) DailyRoomsMessages(
ctx context.Context, serverName gomatrixserverlib.ServerName,
ctx context.Context, serverName spec.ServerName,
) (stats types.MessageStats, activeRooms, activeE2EERooms int64, err error) {
return d.Stats.DailyRoomsMessages(ctx, nil, serverName)
}
@ -996,7 +997,7 @@ func (d *KeyDatabase) KeyChanges(ctx context.Context, fromOffset, toOffset int64
// StaleDeviceLists returns a list of user IDs ending with the domains provided who have stale device lists.
// If no domains are given, all user IDs with stale device lists are returned.
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []gomatrixserverlib.ServerName) ([]string, error) {
func (d *KeyDatabase) StaleDeviceLists(ctx context.Context, domains []spec.ServerName) ([]string, error) {
return d.StaleDeviceListsTable.SelectUserIDsWithStaleDeviceLists(ctx, domains)
}
@ -1038,7 +1039,7 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string
result := fclient.CrossSigningKey{
UserID: userID,
Usage: []fclient.CrossSigningKeyPurpose{purpose},
Keys: map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{
Keys: map[gomatrixserverlib.KeyID]spec.Base64Bytes{
keyID: key,
},
}
@ -1051,10 +1052,10 @@ func (d *KeyDatabase) CrossSigningKeysForUser(ctx context.Context, userID string
continue
}
if result.Signatures == nil {
result.Signatures = map[string]map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
result.Signatures = map[string]map[gomatrixserverlib.KeyID]spec.Base64Bytes{}
}
if _, ok := result.Signatures[sigUserID]; !ok {
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]gomatrixserverlib.Base64Bytes{}
result.Signatures[sigUserID] = map[gomatrixserverlib.KeyID]spec.Base64Bytes{}
}
for sigKeyID, sigBytes := range forSigUserID {
result.Signatures[sigUserID][sigKeyID] = sigBytes
@ -1092,7 +1093,7 @@ func (d *KeyDatabase) StoreCrossSigningSigsForTarget(
ctx context.Context,
originUserID string, originKeyID gomatrixserverlib.KeyID,
targetUserID string, targetKeyID gomatrixserverlib.KeyID,
signature gomatrixserverlib.Base64Bytes,
signature spec.Base64Bytes,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.CrossSigningSigsTable.UpsertCrossSigningSigsForTarget(ctx, nil, originUserID, originKeyID, targetUserID, targetKeyID, signature); err != nil {