Add UserAPI storage tests (#2384)

* Add tests for parts of the userapi storage

* Add tests for keybackup

* Add LoginToken tests

* Add OpenID tests

* Add profile tests

* Add pusher tests

* Add ThreePID tests

* Add notification tests

* Add more device tests, fix numeric localpart query

* Fix failing CI

* Fix numeric local part query
This commit is contained in:
Till 2022-04-27 15:05:49 +02:00 committed by GitHub
parent d7cc187ec0
commit f023cdf8c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 640 additions and 77 deletions

1
go.mod
View file

@ -47,6 +47,7 @@ require (
github.com/pressly/goose v2.7.0+incompatible github.com/pressly/goose v2.7.0+incompatible
github.com/prometheus/client_golang v1.12.1 github.com/prometheus/client_golang v1.12.1
github.com/sirupsen/logrus v1.8.1 github.com/sirupsen/logrus v1.8.1
github.com/stretchr/testify v1.7.0
github.com/tidwall/gjson v1.14.0 github.com/tidwall/gjson v1.14.0
github.com/tidwall/sjson v1.2.4 github.com/tidwall/sjson v1.2.4
github.com/uber/jaeger-client-go v2.30.0+incompatible github.com/uber/jaeger-client-go v2.30.0+incompatible

View file

@ -21,6 +21,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
_ "net/http/pprof"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
@ -56,8 +57,6 @@ import (
userapi "github.com/matrix-org/dendrite/userapi/api" userapi "github.com/matrix-org/dendrite/userapi/api"
userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp" userapiinthttp "github.com/matrix-org/dendrite/userapi/inthttp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
_ "net/http/pprof"
) )
// BaseDendrite is a base for creating new instances of dendrite. It parses // BaseDendrite is a base for creating new instances of dendrite. It parses
@ -273,7 +272,7 @@ func (b *BaseDendrite) PushGatewayHTTPClient() pushgateway.Client {
// CreateAccountsDB creates a new instance of the accounts database. Should only // CreateAccountsDB creates a new instance of the accounts database. Should only
// be called once per component. // be called once per component.
func (b *BaseDendrite) CreateAccountsDB() userdb.Database { func (b *BaseDendrite) CreateAccountsDB() userdb.Database {
db, err := userdb.NewDatabase( db, err := userdb.NewUserAPIDatabase(
&b.Cfg.UserAPI.AccountDatabase, &b.Cfg.UserAPI.AccountDatabase,
b.Cfg.Global.ServerName, b.Cfg.Global.ServerName,
b.Cfg.UserAPI.BCryptCost, b.Cfg.UserAPI.BCryptCost,

View file

@ -27,18 +27,24 @@ import (
type Profile interface { type Profile interface {
GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error) GetProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error) SearchProfiles(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error SetAvatarURL(ctx context.Context, localpart string, avatarURL string) error
SetDisplayName(ctx context.Context, localpart string, displayName string) error SetDisplayName(ctx context.Context, localpart string, displayName string) error
} }
type Database interface { type Account interface {
Profile
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
// CreateAccount makes a new account with the given login name and password, and creates an empty profile // CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the // for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists. // account already exists, it will return nil, ErrUserExists.
CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error) CreateAccount(ctx context.Context, localpart string, plaintextPassword string, appserviceID string, accountType api.AccountType) (*api.Account, error)
GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error)
GetNewNumericLocalpart(ctx context.Context) (int64, error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SetPassword(ctx context.Context, localpart string, plaintextPassword string) error
}
type AccountData interface {
SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error SaveAccountData(ctx context.Context, localpart, roomID, dataType string, content json.RawMessage) error
GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error) GetAccountData(ctx context.Context, localpart string) (global map[string]json.RawMessage, rooms map[string]map[string]json.RawMessage, err error)
// GetAccountDataByType returns account data matching a given // GetAccountDataByType returns account data matching a given
@ -46,26 +52,9 @@ type Database interface {
// If no account data could be found, returns nil // If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval // Returns an error if there was an issue with the retrieval
GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error) GetAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
GetNewNumericLocalpart(ctx context.Context) (int64, error) }
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
CheckAccountAvailability(ctx context.Context, localpart string) (bool, error)
GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
CreateOpenIDToken(ctx context.Context, token, localpart string) (exp int64, err error)
GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
// Key backups
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
type Device interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error) GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error) GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error) GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
@ -79,11 +68,22 @@ type Database interface {
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error) CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error RemoveDevices(ctx context.Context, localpart string, devices []string) error
// RemoveAllDevices deleted all devices for this user. Returns the devices deleted. // RemoveAllDevices deleted all devices for this user. Returns the devices deleted.
RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error)
}
type KeyBackup interface {
CreateKeyBackup(ctx context.Context, userID, algorithm string, authData json.RawMessage) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, userID, version string, authData json.RawMessage) (err error)
DeleteKeyBackup(ctx context.Context, userID, version string) (exists bool, err error)
GetKeyBackup(ctx context.Context, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
UpsertBackupKeys(ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession) (count int64, etag string, err error)
GetBackupKeys(ctx context.Context, version, userID, filterRoomID, filterSessionID string) (result map[string]map[string]api.KeyBackupSession, err error)
CountBackupKeys(ctx context.Context, version, userID string) (count int64, err error)
}
type LoginToken interface {
// CreateLoginToken generates a token, stores and returns it. The lifetime is // CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor. // determined by the loginTokenLifetime given to the Database constructor.
CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error)
@ -94,21 +94,50 @@ type Database interface {
// GetLoginTokenDataByToken returns the data associated with the given token. // GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows. // May return sql.ErrNoRows.
GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error type OpenID interface {
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error) CreateOpenIDToken(ctx context.Context, token, userID string) (exp int64, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, b bool) (affected bool, err error) GetOpenIDTokenAttributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) }
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error
type Pusher interface {
UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error UpsertPusher(ctx context.Context, p api.Pusher, localpart string) error
GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error) GetPushers(ctx context.Context, localpart string) ([]api.Pusher, error)
RemovePusher(ctx context.Context, appid, pushkey, localpart string) error RemovePusher(ctx context.Context, appid, pushkey, localpart string) error
RemovePushers(ctx context.Context, appid, pushkey string) error RemovePushers(ctx context.Context, appid, pushkey string) error
} }
type ThreePID interface {
SaveThreePIDAssociation(ctx context.Context, threepid, localpart, medium string) (err error)
RemoveThreePIDAssociation(ctx context.Context, threepid string, medium string) (err error)
GetLocalpartForThreePID(ctx context.Context, threepid string, medium string) (localpart string, err error)
GetThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
}
type Notification interface {
InsertNotification(ctx context.Context, localpart, eventID string, pos int64, tweaks map[string]interface{}, n *api.Notification) error
DeleteNotificationsUpTo(ctx context.Context, localpart, roomID string, pos int64) (affected bool, err error)
SetNotificationsRead(ctx context.Context, localpart, roomID string, pos int64, read bool) (affected bool, err error)
GetNotifications(ctx context.Context, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error)
GetNotificationCount(ctx context.Context, localpart string, filter tables.NotificationFilter) (int64, error)
GetRoomNotificationCounts(ctx context.Context, localpart, roomID string) (total int64, highlight int64, _ error)
DeleteOldNotifications(ctx context.Context) error
}
type Database interface {
Account
AccountData
Device
KeyBackup
LoginToken
Notification
OpenID
Profile
Pusher
ThreePID
}
// Err3PIDInUse is the error returned when trying to save an association involving // Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user. // a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use") var Err3PIDInUse = errors.New("this third-party identifier is already in use")

View file

@ -47,8 +47,6 @@ CREATE TABLE IF NOT EXISTS account_accounts (
-- TODO: -- TODO:
-- upgraded_ts, devices, any email reset stuff? -- upgraded_ts, devices, any email reset stuff?
); );
-- Create sequence for autogenerated numeric usernames
CREATE SEQUENCE IF NOT EXISTS numeric_username_seq START 1;
` `
const insertAccountSQL = "" + const insertAccountSQL = "" +
@ -67,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = FALSE"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT nextval('numeric_username_seq')" "SELECT COALESCE(MAX(localpart::integer), 0) FROM account_accounts WHERE localpart ~ '^[0-9]*$'"
type accountsStatements struct { type accountsStatements struct {
insertAccountStmt *sql.Stmt insertAccountStmt *sql.Stmt
@ -178,5 +176,5 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return return id + 1, err
} }

View file

@ -78,7 +78,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -93,7 +93,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id = ANY($1)" "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id = ANY($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
@ -235,16 +235,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device var devices []api.Device
var dev api.Device
var localpart string
var lastseents sql.NullInt64
var displayName sql.NullString
for rows.Next() { for rows.Next() {
var dev api.Device if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err return nil, err
} }
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }
@ -262,10 +266,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
var dev api.Device
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
for rows.Next() { for rows.Next() {
var dev api.Device
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
if err != nil { if err != nil {
return devices, err return devices, err

View file

@ -577,21 +577,6 @@ func (d *Database) UpdateDevice(
}) })
} }
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database // RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart. // matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error // If the devices don't exist, it will not return an error

View file

@ -65,7 +65,7 @@ const selectPasswordHashSQL = "" +
"SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0" "SELECT password_hash FROM account_accounts WHERE localpart = $1 AND is_deactivated = 0"
const selectNewNumericLocalpartSQL = "" + const selectNewNumericLocalpartSQL = "" +
"SELECT COUNT(localpart) FROM account_accounts" "SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM account_accounts WHERE CAST(localpart AS INT) <> 0"
type accountsStatements struct { type accountsStatements struct {
db *sql.DB db *sql.DB
@ -121,6 +121,7 @@ func (s *accountsStatements) InsertAccount(
UserID: userutil.MakeUserID(localpart, s.serverName), UserID: userutil.MakeUserID(localpart, s.serverName),
ServerName: s.serverName, ServerName: s.serverName,
AppServiceID: appserviceID, AppServiceID: appserviceID,
AccountType: accountType,
}, nil }, nil
} }
@ -177,5 +178,8 @@ func (s *accountsStatements) SelectNewNumericLocalpart(
stmt = sqlutil.TxStmt(txn, stmt) stmt = sqlutil.TxStmt(txn, stmt)
} }
err = stmt.QueryRowContext(ctx).Scan(&id) err = stmt.QueryRowContext(ctx).Scan(&id)
return if err == sql.ErrNoRows {
return 1, nil
}
return id + 1, err
} }

View file

@ -63,7 +63,7 @@ const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2" "SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" + const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2" "SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM device_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
const updateDeviceNameSQL = "" + const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3" "UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
@ -78,7 +78,7 @@ const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)" "DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
const selectDevicesByIDSQL = "" + const selectDevicesByIDSQL = "" +
"SELECT device_id, localpart, display_name FROM device_devices WHERE device_id IN ($1)" "SELECT device_id, localpart, display_name, last_seen_ts FROM device_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
const updateDeviceLastSeen = "" + const updateDeviceLastSeen = "" +
"UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4" "UPDATE device_devices SET last_seen_ts = $1, ip = $2 WHERE localpart = $3 AND device_id = $4"
@ -235,10 +235,10 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
return devices, err return devices, err
} }
var dev api.Device
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
for rows.Next() { for rows.Next() {
var dev api.Device
var lastseents sql.NullInt64
var id, displayname, ip, useragent sql.NullString
err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent) err = rows.Scan(&id, &displayname, &lastseents, &ip, &useragent)
if err != nil { if err != nil {
return devices, err return devices, err
@ -279,16 +279,20 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
} }
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed") defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByID: rows.close() failed")
var devices []api.Device var devices []api.Device
var dev api.Device
var localpart string
var displayName sql.NullString
var lastseents sql.NullInt64
for rows.Next() { for rows.Next() {
var dev api.Device if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
var localpart string
var displayName sql.NullString
if err := rows.Scan(&dev.ID, &localpart, &displayName); err != nil {
return nil, err return nil, err
} }
if displayName.Valid { if displayName.Valid {
dev.DisplayName = displayName.String dev.DisplayName = displayName.String
} }
if lastseents.Valid {
dev.LastSeenTS = lastseents.Int64
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName) dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev) devices = append(devices, dev)
} }

View file

@ -28,9 +28,9 @@ import (
"github.com/matrix-org/dendrite/userapi/storage/sqlite3" "github.com/matrix-org/dendrite/userapi/storage/sqlite3"
) )
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) // NewUserAPIDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters // and sets postgres connection parameters
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) { func NewUserAPIDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration, serverNoticesLocalpart string) (Database, error) {
switch { switch {
case dbProperties.ConnectionString.IsSQLite(): case dbProperties.ConnectionString.IsSQLite():
return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart) return sqlite3.NewDatabase(dbProperties, serverName, bcryptCost, openIDTokenLifetimeMS, loginTokenLifetime, serverNoticesLocalpart)

View file

@ -0,0 +1,539 @@
package storage_test
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/pushrules"
"github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/test"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib"
"github.com/matrix-org/util"
"github.com/stretchr/testify/assert"
"golang.org/x/crypto/bcrypt"
)
const loginTokenLifetime = time.Minute
var (
openIDLifetimeMS = time.Minute.Milliseconds()
ctx = context.Background()
)
func mustCreateDatabase(t *testing.T, dbType test.DBType) (storage.Database, func()) {
connStr, close := test.PrepareDBConnectionString(t, dbType)
db, err := storage.NewUserAPIDatabase(&config.DatabaseOptions{
ConnectionString: config.DataSource(connStr),
}, "localhost", bcrypt.MinCost, openIDLifetimeMS, loginTokenLifetime, "_server")
if err != nil {
t.Fatalf("NewUserAPIDatabase returned %s", err)
}
return db, close
}
// Tests storing and getting account data
func Test_AccountData(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
events := room.Events()
contentRoom := json.RawMessage(fmt.Sprintf(`{"event_id":"%s"}`, events[len(events)-1].EventID()))
err = db.SaveAccountData(ctx, localpart, room.ID, "m.fully_read", contentRoom)
assert.NoError(t, err, "unable to save account data")
contentGlobal := json.RawMessage(fmt.Sprintf(`{"recent_rooms":["%s"]}`, room.ID))
err = db.SaveAccountData(ctx, localpart, "", "im.vector.setting.breadcrumbs", contentGlobal)
assert.NoError(t, err, "unable to save account data")
accountData, err := db.GetAccountDataByType(ctx, localpart, room.ID, "m.fully_read")
assert.NoError(t, err, "unable to get account data by type")
assert.Equal(t, contentRoom, accountData)
globalData, roomData, err := db.GetAccountData(ctx, localpart)
assert.NoError(t, err)
assert.Equal(t, contentRoom, roomData[room.ID]["m.fully_read"])
assert.Equal(t, contentGlobal, globalData["im.vector.setting.breadcrumbs"])
})
}
// Tests the creation of accounts
func Test_Accounts(t *testing.T) {
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
accAlice, err := db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
// verify the newly create account is the same as returned by CreateAccount
var accGet *api.Account
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "testing")
assert.NoError(t, err, "failed to get account by password")
assert.Equal(t, accAlice, accGet)
accGet, err = db.GetAccountByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "failed to get account by localpart")
assert.Equal(t, accAlice, accGet)
// check account availability
available, err := db.CheckAccountAvailability(ctx, aliceLocalpart)
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, false, available)
available, err = db.CheckAccountAvailability(ctx, "unusedname")
assert.NoError(t, err, "failed to checkout account availability")
assert.Equal(t, true, available)
// get guest account numeric aliceLocalpart
first, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err, "failed to get new numeric localpart")
// Create a new account to verify the numeric localpart is updated
_, err = db.CreateAccount(ctx, "", "testing", "", api.AccountTypeGuest)
assert.NoError(t, err, "failed to create account")
second, err := db.GetNewNumericLocalpart(ctx)
assert.NoError(t, err)
assert.Greater(t, second, first)
// update password for alice
err = db.SetPassword(ctx, aliceLocalpart, "newPassword")
assert.NoError(t, err, "failed to update password")
accGet, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
assert.NoError(t, err, "failed to get account by new password")
assert.Equal(t, accAlice, accGet)
// deactivate account
err = db.DeactivateAccount(ctx, aliceLocalpart)
assert.NoError(t, err, "failed to deactivate account")
// This should fail now, as the account is deactivated
_, err = db.GetAccountByPassword(ctx, aliceLocalpart, "newPassword")
assert.Error(t, err, "expected an error, got none")
_, err = db.GetAccountByLocalpart(ctx, "unusename")
assert.Error(t, err, "expected an error for non existent localpart")
})
}
func Test_Devices(t *testing.T) {
alice := test.NewUser()
localpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
deviceID := util.RandomString(8)
accessToken := util.RandomString(16)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
deviceWithID, err := db.CreateDevice(ctx, localpart, &deviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDevice, err := db.GetDeviceByID(ctx, localpart, deviceID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithID.ID, gotDevice.ID) // GetDeviceByID doesn't populate all fields
gotDeviceAccessToken, err := db.GetDeviceByAccessToken(ctx, accessToken)
assert.NoError(t, err, "unable to get device by access token")
assert.Equal(t, deviceWithID.ID, gotDeviceAccessToken.ID) // GetDeviceByAccessToken doesn't populate all fields
// create a device without existing device ID
accessToken = util.RandomString(16)
deviceWithoutID, err := db.CreateDevice(ctx, localpart, nil, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create deviceWithoutID")
gotDeviceWithoutID, err := db.GetDeviceByID(ctx, localpart, deviceWithoutID.ID)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, deviceWithoutID.ID, gotDeviceWithoutID.ID) // GetDeviceByID doesn't populate all fields
// Get devices
devices, err := db.GetDevicesByLocalpart(ctx, localpart)
assert.NoError(t, err, "unable to get devices by localpart")
assert.Equal(t, 2, len(devices))
deviceIDs := make([]string, 0, len(devices))
for _, dev := range devices {
deviceIDs = append(deviceIDs, dev.ID)
}
devices2, err := db.GetDevicesByID(ctx, deviceIDs)
assert.NoError(t, err, "unable to get devices by id")
assert.Equal(t, devices, devices2)
// Update device
newName := "new display name"
err = db.UpdateDevice(ctx, localpart, deviceWithID.ID, &newName)
assert.NoError(t, err, "unable to update device displayname")
err = db.UpdateDeviceLastSeen(ctx, localpart, deviceWithID.ID, "127.0.0.1")
assert.NoError(t, err, "unable to update device last seen")
deviceWithID.DisplayName = newName
deviceWithID.LastSeenIP = "127.0.0.1"
deviceWithID.LastSeenTS = int64(gomatrixserverlib.AsTimestamp(time.Now().Truncate(time.Second)))
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 2, len(devices))
assert.Equal(t, deviceWithID.DisplayName, devices[0].DisplayName)
assert.Equal(t, deviceWithID.LastSeenIP, devices[0].LastSeenIP)
truncatedTime := gomatrixserverlib.Timestamp(devices[0].LastSeenTS).Time().Truncate(time.Second)
assert.Equal(t, gomatrixserverlib.Timestamp(deviceWithID.LastSeenTS), gomatrixserverlib.AsTimestamp(truncatedTime))
// create one more device and remove the devices step by step
newDeviceID := util.RandomString(16)
accessToken = util.RandomString(16)
_, err = db.CreateDevice(ctx, localpart, &newDeviceID, accessToken, nil, "", "")
assert.NoError(t, err, "unable to create new device")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 3, len(devices))
err = db.RemoveDevices(ctx, localpart, deviceIDs)
assert.NoError(t, err, "unable to remove devices")
devices, err = db.GetDevicesByLocalpart(ctx, localpart)
assert.NoError(t, err, "unable to get device by id")
assert.Equal(t, 1, len(devices))
deleted, err := db.RemoveAllDevices(ctx, localpart, "")
assert.NoError(t, err, "unable to remove all devices")
assert.Equal(t, 1, len(deleted))
assert.Equal(t, newDeviceID, deleted[0].ID)
})
}
func Test_KeyBackup(t *testing.T) {
alice := test.NewUser()
room := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
wantAuthData := json.RawMessage("my auth data")
wantVersion, err := db.CreateKeyBackup(ctx, alice.ID, "dummyAlgo", wantAuthData)
assert.NoError(t, err, "unable to create key backup")
// get key backup by version
gotVersion, gotAlgo, gotAuthData, _, _, err := db.GetKeyBackup(ctx, alice.ID, wantVersion)
assert.NoError(t, err, "unable to get key backup")
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
// get any key backup
gotVersion, gotAlgo, gotAuthData, _, _, err = db.GetKeyBackup(ctx, alice.ID, "")
assert.NoError(t, err, "unable to get key backup")
assert.Equal(t, wantVersion, gotVersion, "backup version mismatch")
assert.Equal(t, "dummyAlgo", gotAlgo, "backup algorithm mismatch")
assert.Equal(t, wantAuthData, gotAuthData, "backup auth data mismatch")
err = db.UpdateKeyBackupAuthData(ctx, alice.ID, wantVersion, json.RawMessage("my updated auth data"))
assert.NoError(t, err, "unable to update key backup auth data")
uploads := []api.InternalKeyBackupSession{
{
KeyBackupSession: api.KeyBackupSession{
IsVerified: true,
SessionData: wantAuthData,
},
RoomID: room.ID,
SessionID: "1",
},
{
KeyBackupSession: api.KeyBackupSession{},
RoomID: room.ID,
SessionID: "2",
},
}
count, _, err := db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads)
assert.NoError(t, err, "unable to upsert backup keys")
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
// do it again to update a key
uploads[1].IsVerified = true
count, _, err = db.UpsertBackupKeys(ctx, wantVersion, alice.ID, uploads[1:])
assert.NoError(t, err, "unable to upsert backup keys")
assert.Equal(t, int64(len(uploads)), count, "unexpected backup count")
// get backup keys by session id
gotBackupKeys, err := db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "1")
assert.NoError(t, err, "unable to get backup keys")
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
// get backup keys by room id
gotBackupKeys, err = db.GetBackupKeys(ctx, wantVersion, alice.ID, room.ID, "")
assert.NoError(t, err, "unable to get backup keys")
assert.Equal(t, uploads[0].KeyBackupSession, gotBackupKeys[room.ID]["1"])
gotCount, err := db.CountBackupKeys(ctx, wantVersion, alice.ID)
assert.NoError(t, err, "unable to get backup keys count")
assert.Equal(t, count, gotCount, "unexpected backup count")
// finally delete a key
exists, err := db.DeleteKeyBackup(ctx, alice.ID, wantVersion)
assert.NoError(t, err, "unable to delete key backup")
assert.True(t, exists)
// this key should not exist
exists, err = db.DeleteKeyBackup(ctx, alice.ID, "3")
assert.NoError(t, err, "unable to delete key backup")
assert.False(t, exists)
})
}
func Test_LoginToken(t *testing.T) {
alice := test.NewUser()
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
// create a new token
wantLoginToken := &api.LoginTokenData{UserID: alice.ID}
gotMetadata, err := db.CreateLoginToken(ctx, wantLoginToken)
assert.NoError(t, err, "unable to create login token")
assert.NotNil(t, gotMetadata)
assert.Equal(t, time.Now().Add(loginTokenLifetime).Truncate(loginTokenLifetime), gotMetadata.Expiration.Truncate(loginTokenLifetime))
// get the new token
gotLoginToken, err := db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
assert.NoError(t, err, "unable to get login token")
assert.NotNil(t, gotLoginToken)
assert.Equal(t, wantLoginToken, gotLoginToken, "unexpected login token")
// remove the login token again
err = db.RemoveLoginToken(ctx, gotMetadata.Token)
assert.NoError(t, err, "unable to remove login token")
// check if the token was actually deleted
_, err = db.GetLoginTokenDataByToken(ctx, gotMetadata.Token)
assert.Error(t, err, "expected an error, but got none")
})
}
func Test_OpenID(t *testing.T) {
alice := test.NewUser()
token := util.RandomString(24)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + openIDLifetimeMS
expires, err := db.CreateOpenIDToken(ctx, token, alice.ID)
assert.NoError(t, err, "unable to create OpenID token")
assert.Equal(t, expiresAtMS, expires)
attributes, err := db.GetOpenIDTokenAttributes(ctx, token)
assert.NoError(t, err, "unable to get OpenID token attributes")
assert.Equal(t, alice.ID, attributes.UserID)
assert.Equal(t, expiresAtMS, attributes.ExpiresAtMS)
})
}
func Test_Profile(t *testing.T) {
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
// create account, which also creates a profile
_, err = db.CreateAccount(ctx, aliceLocalpart, "testing", "", api.AccountTypeAdmin)
assert.NoError(t, err, "failed to create account")
gotProfile, err := db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
wantProfile := &authtypes.Profile{Localpart: aliceLocalpart}
assert.Equal(t, wantProfile, gotProfile)
// set avatar & displayname
wantProfile.DisplayName = "Alice"
wantProfile.AvatarURL = "mxc://aliceAvatar"
err = db.SetDisplayName(ctx, aliceLocalpart, "Alice")
assert.NoError(t, err, "unable to set displayname")
err = db.SetAvatarURL(ctx, aliceLocalpart, "mxc://aliceAvatar")
assert.NoError(t, err, "unable to set avatar url")
// verify profile
gotProfile, err = db.GetProfileByLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get profile by localpart")
assert.Equal(t, wantProfile, gotProfile)
// search profiles
searchRes, err := db.SearchProfiles(ctx, "Alice", 2)
assert.NoError(t, err, "unable to search profiles")
assert.Equal(t, 1, len(searchRes))
assert.Equal(t, *wantProfile, searchRes[0])
})
}
func Test_Pusher(t *testing.T) {
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
appID := util.RandomString(8)
var pushKeys []string
var gotPushers []api.Pusher
for i := 0; i < 2; i++ {
pushKey := util.RandomString(8)
wantPusher := api.Pusher{
PushKey: pushKey,
Kind: api.HTTPKind,
AppID: appID,
AppDisplayName: util.RandomString(8),
DeviceDisplayName: util.RandomString(8),
ProfileTag: util.RandomString(8),
Language: util.RandomString(2),
}
err = db.UpsertPusher(ctx, wantPusher, aliceLocalpart)
assert.NoError(t, err, "unable to upsert pusher")
// check it was actually persisted
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, i+1, len(gotPushers))
assert.Equal(t, wantPusher, gotPushers[i])
pushKeys = append(pushKeys, pushKey)
}
// remove single pusher
err = db.RemovePusher(ctx, appID, pushKeys[0], aliceLocalpart)
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err := db.GetPushers(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 1, len(gotPushers))
// remove last pusher
err = db.RemovePushers(ctx, appID, pushKeys[1])
assert.NoError(t, err, "unable to remove pusher")
gotPushers, err = db.GetPushers(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get pushers")
assert.Equal(t, 0, len(gotPushers))
})
}
func Test_ThreePID(t *testing.T) {
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
threePID := util.RandomString(8)
medium := util.RandomString(8)
err = db.SaveThreePIDAssociation(ctx, threePID, aliceLocalpart, medium)
assert.NoError(t, err, "unable to save threepid association")
// get the stored threepid
gotLocalpart, err := db.GetLocalpartForThreePID(ctx, threePID, medium)
assert.NoError(t, err, "unable to get localpart for threepid")
assert.Equal(t, aliceLocalpart, gotLocalpart)
threepids, err := db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 1, len(threepids))
assert.Equal(t, authtypes.ThreePID{
Address: threePID,
Medium: medium,
}, threepids[0])
// remove threepid association
err = db.RemoveThreePIDAssociation(ctx, threePID, medium)
assert.NoError(t, err, "unexpected error")
// verify it was deleted
threepids, err = db.GetThreePIDsForLocalpart(ctx, aliceLocalpart)
assert.NoError(t, err, "unable to get threepids for localpart")
assert.Equal(t, 0, len(threepids))
})
}
func Test_Notification(t *testing.T) {
alice := test.NewUser()
aliceLocalpart, _, err := gomatrixserverlib.SplitID('@', alice.ID)
assert.NoError(t, err)
room := test.NewRoom(t, alice)
room2 := test.NewRoom(t, alice)
test.WithAllDatabases(t, func(t *testing.T, dbType test.DBType) {
db, close := mustCreateDatabase(t, dbType)
defer close()
// generate some dummy notifications
for i := 0; i < 10; i++ {
eventID := util.RandomString(16)
roomID := room.ID
ts := time.Now()
if i > 5 {
roomID = room2.ID
// create some old notifications to test DeleteOldNotifications
ts = ts.AddDate(0, -2, 0)
}
notification := &api.Notification{
Actions: []*pushrules.Action{
{},
},
Event: gomatrixserverlib.ClientEvent{
Content: gomatrixserverlib.RawJSON("{}"),
},
Read: false,
RoomID: roomID,
TS: gomatrixserverlib.AsTimestamp(ts),
}
err = db.InsertNotification(ctx, aliceLocalpart, eventID, int64(i+1), nil, notification)
assert.NoError(t, err, "unable to insert notification")
}
// get notifications
count, err := db.GetNotificationCount(ctx, aliceLocalpart, tables.AllNotifications)
assert.NoError(t, err, "unable to get notification count")
assert.Equal(t, int64(10), count)
notifs, count, err := db.GetNotifications(ctx, aliceLocalpart, 0, 15, tables.AllNotifications)
assert.NoError(t, err, "unable to get notifications")
assert.Equal(t, int64(10), count)
assert.Equal(t, 10, len(notifs))
// ... for a specific room
total, _, err := db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(4), total)
// mark notification as read
affected, err := db.SetNotificationsRead(ctx, aliceLocalpart, room2.ID, 7, true)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
// this should delete 2 notifications
affected, err = db.DeleteNotificationsUpTo(ctx, aliceLocalpart, room2.ID, 8)
assert.NoError(t, err, "unable to set notifications read")
assert.True(t, affected)
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(2), total)
// delete old notifications
err = db.DeleteOldNotifications(ctx)
assert.NoError(t, err)
// this should now return 0 notifications
total, _, err = db.GetRoomNotificationCounts(ctx, aliceLocalpart, room2.ID)
assert.NoError(t, err, "unable to get notifications for room")
assert.Equal(t, int64(0), total)
})
}

View file

@ -23,7 +23,7 @@ import (
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
func NewDatabase( func NewUserAPIDatabase(
dbProperties *config.DatabaseOptions, dbProperties *config.DatabaseOptions,
serverName gomatrixserverlib.ServerName, serverName gomatrixserverlib.ServerName,
bcryptCost int, bcryptCost int,

View file

@ -52,7 +52,7 @@ func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, s
MaxOpenConnections: 1, MaxOpenConnections: 1,
MaxIdleConnections: 1, MaxIdleConnections: 1,
} }
accountDB, err := storage.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "") accountDB, err := storage.NewUserAPIDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS, opts.loginTokenLifetime, "")
if err != nil { if err != nil {
t.Fatalf("failed to create account DB: %s", err) t.Fatalf("failed to create account DB: %s", err)
} }