Refactor user API storage (#2202)

* Refactor User API database

* Fix migration bugs
This commit is contained in:
Neil Alexander 2022-02-18 13:51:59 +00:00 committed by GitHub
parent 9bd5e414c9
commit 9f4a39e8e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1165 additions and 1671 deletions

View file

@ -21,6 +21,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -56,19 +57,20 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewPostgresAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
_, err = db.Exec(accountDataSchema) s := &accountDataStatements{}
_, err := db.Exec(accountDataSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt) stmt := sqlutil.TxStmt(txn, s.insertAccountDataStmt)
@ -76,7 +78,7 @@ func (s *accountDataStatements) insertAccountData(
return return
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
@ -114,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, rows.Err() return global, rooms, rows.Err()
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -78,14 +79,15 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func NewPostgresAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
_, err := db.Exec(accountsSchema) s := &accountsStatements{
return err serverName: serverName,
} }
_, err := db.Exec(accountsSchema)
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { if err != nil {
s.serverName = server return nil, err
return sqlutil.StatementList{ }
return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
@ -123,28 +125,28 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
@ -168,7 +170,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
) )
@ -111,53 +112,32 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
serverName: serverName,
}
_, err := db.Exec(devicesSchema) _, err := db.Exec(devicesSchema)
return err if err != nil {
return nil, err
} }
return s, sqlutil.StatementList{
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { {&s.insertDeviceStmt, insertDeviceSQL},
if err = s.execSchema(db); err != nil { {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
return {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
} {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { {&s.updateDeviceNameStmt, updateDeviceNameSQL},
return {&s.deleteDeviceStmt, deleteDeviceSQL},
} {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil { {&s.deleteDevicesStmt, deleteDevicesSQL},
return {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
} {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil { }.Prepare(db)
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
} }
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice(
} }
// deleteDevice removes a single device by id and user localpart. // deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
// deleteDevices removes a single or multiple devices by ids and user localpart. // deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed. // Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
// deleteDevicesByLocalpart removes all devices for the // deleteDevicesByLocalpart removes all devices for the
// given user localpart. // given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
return err return err
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs)) rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
if err != nil { if err != nil {
return nil, err return nil, err
@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
_, err = db.Exec(keyBackupTableSchema) s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL}, {&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s keyBackupStatements) countKeys( func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) { ) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return return
} }
func (s *keyBackupStatements) insertBackupKey( func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return return
} }
func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return return
} }
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomID( func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string, ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
@ -69,12 +70,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func NewPostgresKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
_, err = db.Exec(keyBackupVersionTableSchema) s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -84,7 +86,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
@ -92,7 +94,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
func (s *keyBackupVersionStatements) updateKeyBackupAuthData( func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -103,7 +105,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag( func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string, ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -114,7 +116,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err return err
} }
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -132,7 +134,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil return ra == 1, nil
} }
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64

View file

@ -21,18 +21,11 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginTokenStatements struct { const loginTokenSchema = `
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt
}
// execSchema ensures tables and indices exist.
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
@ -45,24 +38,38 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`) `
return err
const insertLoginTokenSQL = "" +
"INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
type loginTokenStatements struct {
insertStmt *sql.Stmt
deleteStmt *sql.Stmt
selectStmt *sql.Stmt
} }
// prepare runs statement preparation. func NewPostgresLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
func (s *loginTokenStatements) prepare(db *sql.DB) error { s := &loginTokenStatements{}
if err := s.execSchema(db); err != nil { _, err := db.Exec(loginTokenSchema)
return err if err != nil {
return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, {&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, {&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insert adds an already generated token to the database. // insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt) stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err return err
@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil { if err != nil {
@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
} }
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. // selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -22,33 +23,35 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
); );
` `
const insertTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type openIDTokenStatements struct {
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func NewPostgresOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
_, err = db.Exec(openIDTokenSchema) s := &openIDTokenStatements{
if err != nil { serverName: serverName,
return
} }
s.serverName = server _, err := db.Exec(openIDTokenSchema)
return sqlutil.StatementList{ if err != nil {
{&s.insertTokenStmt, insertTokenSQL}, return nil, err
{&s.selectTokenStmt, selectTokenSQL}, }
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string,
@ -61,7 +64,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes( func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const profilesSchema = ` const profilesSchema = `
@ -59,12 +60,13 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func NewPostgresProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
_, err = db.Exec(profilesSchema) s := &profilesStatements{}
_, err := db.Exec(profilesSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL}, {&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL},
@ -73,14 +75,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) (err error) { ) (err error) {
_, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err = sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return return
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
@ -93,21 +95,21 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil return &profile, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (err error) {
_, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart) _, err = s.setAvatarURLStmt.ExecContext(ctx, avatarURL, localpart)
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (err error) {
_, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart) _, err = s.setDisplayNameStmt.ExecContext(ctx, displayName, localpart)
return return
} }
func (s *profilesStatements) selectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile

View file

@ -15,76 +15,33 @@
package postgres package postgres
import ( import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt" "fmt"
"strconv"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/postgres/deltas" "github.com/matrix-org/dendrite/userapi/storage/postgres/deltas"
"github.com/matrix-org/dendrite/userapi/storage/shared"
// Import the postgres database driver. // Import the postgres database driver.
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
// Database represents an account database
type Database struct {
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
keyBackups keyBackupStatements
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewDummyWriter(),
loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing, m := sqlutil.NewMigrations()
// and THEN prepare statements so we don't fail due to referencing new columns if _, err = db.Exec(accountsSchema); err != nil {
if err = d.accounts.execSchema(db); err != nil { // do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
@ -92,638 +49,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
return nil, err return nil, err
} }
if err = d.PartitionOffsetStatements.Prepare(db, d.writer, "account"); err != nil { accountDataTable, err := NewPostgresAccountDataTable(db)
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
if err = d.devices.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.loginTokens.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("NewPostgresAccountDataTable: %w", err)
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { accountsTable, err := NewPostgresAccountsTable(db, serverName)
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.profiles.setAvatarURL(ctx, localpart, avatarURL)
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.profiles.setDisplayName(ctx, localpart, displayName)
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresAccountsTable: %w", err)
} }
return d.accounts.updatePassword(ctx, localpart, hash) devicesTable, err := NewPostgresDevicesTable(db, serverName)
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, sqlutil.ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresDevicesTable: %w", err)
} }
localpart = strconv.FormatInt(numLocalpart, 10) keyBackupTable, err := NewPostgresKeyBackupTable(db)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err
})
return
}
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var account *api.Account
var err error
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("NewPostgresKeyBackupTable: %w", err)
} }
} keyBackupVersionTable, err := NewPostgresKeyBackupVersionTable(db)
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
if sqlutil.IsUniqueConstraintViolationErr(err) {
return nil, sqlutil.ErrUserExists
}
return nil, err
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresKeyBackupVersionTable: %w", err)
} }
loginTokenTable, err := NewPostgresLoginTokenTable(db)
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.threepids.deleteThreePID(ctx, threepid, medium)
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.accounts.deactivateAccount(ctx, localpart)
}
// CreateOpenIDToken persists a new token that was issued through OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresLoginTokenTable: %w", err)
} }
return nil openIDTable, err := NewPostgresOpenIDTable(db, serverName)
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresOpenIDTable: %w", err)
} }
if deleted { profilesTable, err := NewPostgresProfilesTable(db)
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewPostgresProfilesTable: %w", err)
} }
threePIDTable, err := NewPostgresThreePIDTable(db)
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil { if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) return nil, fmt.Errorf("NewPostgresThreePIDTable: %w", err)
} }
} return &shared.Database{
// if we shouldn't replace the key we do nothing with it AccountDatas: accountDataTable,
continue Accounts: accountsTable,
} Devices: devicesTable,
} KeyBackups: keyBackupTable,
// if we're here, either the room or session are new, either way, we insert KeyBackupVersions: keyBackupVersionTable,
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) LoginTokens: loginTokenTable,
changed = true OpenIDTokens: openIDTable,
if err != nil { Profiles: profilesTable,
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) ThreePIDs: threePIDTable,
} ServerName: serverName,
} DB: db,
Writer: sqlutil.NewDummyWriter(),
count, err = d.keyBackups.countKeys(ctx, txn, userID, version) LoginTokenLifetime: loginTokenLifetime,
if err != nil { BcryptCost: bcryptCost,
return err OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }, nil
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
} }

View file

@ -19,6 +19,7 @@ import (
"database/sql" "database/sql"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -58,12 +59,13 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func NewPostgresThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
_, err = db.Exec(threepidSchema) s := &threepidStatements{}
_, err := db.Exec(threepidSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL},
@ -71,7 +73,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -82,7 +84,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return return
} }
func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -106,7 +108,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return return
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -114,8 +116,9 @@ func (s *threepidStatements) insertThreePID(
return return
} }
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) DeleteThreePID(
ctx context.Context, threepid string, medium string) (err error) { ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
_, err = s.deleteThreePIDStmt.ExecContext(ctx, threepid, medium) stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium)
return return
} }

View file

@ -0,0 +1,672 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package shared
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
)
// Database represents an account database
type Database struct {
DB *sql.DB
Writer sqlutil.Writer
Accounts tables.AccountsTable
Profiles tables.ProfileTable
AccountDatas tables.AccountDataTable
ThreePIDs tables.ThreePIDTable
OpenIDTokens tables.OpenIDTable
KeyBackups tables.KeyBackupTable
KeyBackupVersions tables.KeyBackupVersionTable
Devices tables.DevicesTable
LoginTokens tables.LoginTokenTable
LoginTokenLifetime time.Duration
ServerName gomatrixserverlib.ServerName
BcryptCost int
OpenIDTokenLifetimeMS int64
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
)
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.Accounts.SelectPasswordHash(ctx, localpart)
if err != nil {
return nil, err
}
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil {
return nil, err
}
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.Profiles.SelectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetAvatarURL(ctx, txn, localpart, avatarURL)
})
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Profiles.SetDisplayName(ctx, txn, localpart, displayName)
})
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil {
return err
}
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.UpdatePassword(ctx, localpart, hash)
})
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.Accounts.SelectNewNumericLocalpart(ctx, txn)
if err != nil {
return err
}
localpart = strconv.FormatInt(numLocalpart, 10)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err
})
return
}
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var err error
var account *api.Account
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil {
return nil, err
}
}
if account, err = d.Accounts.InsertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.Profiles.InsertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.AccountDatas.InsertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.AccountDatas.InsertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.AccountDatas.SelectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.AccountDatas.SelectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.Accounts.SelectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.BcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
user, err := d.ThreePIDs.SelectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil {
return err
}
if len(user) > 0 {
return Err3PIDInUse
}
return d.ThreePIDs.InsertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.ThreePIDs.DeleteThreePID(ctx, txn, threepid, medium)
})
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.ThreePIDs.SelectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.ThreePIDs.SelectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.Accounts.SelectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.Accounts.SelectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.Profiles.SelectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.Writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.Accounts.DeactivateAccount(ctx, localpart)
})
}
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.OpenIDTokenLifetimeMS
err := d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.OpenIDTokens.InsertOpenIDToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.OpenIDTokens.SelectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
version, err = d.KeyBackupVersions.InsertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.KeyBackupVersions.UpdateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
exists, err = d.KeyBackupVersions.DeleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.KeyBackups.SelectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.KeyBackups.SelectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.KeyBackups.SelectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil {
return err
}
return nil
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.KeyBackupVersions.SelectKeyBackup(ctx, txn, userID, version)
if err != nil {
return err
}
if deleted {
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.KeyBackups.SelectKeys(ctx, txn, userID, version)
if err != nil {
return err
}
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.KeyBackups.UpdateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.KeyBackups.UpdateBackupKey: %w", err)
}
}
// if we shouldn't replace the key we do nothing with it
continue
}
}
// if we're here, either the room or session are new, either way, we insert
err = d.KeyBackups.InsertBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil {
return fmt.Errorf("d.KeyBackups.InsertBackupKey: %w", err)
}
}
count, err = d.KeyBackups.CountKeys(ctx, txn, userID, version)
if err != nil {
return err
}
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.KeyBackupVersions.UpdateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.Devices.SelectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.Devices.SelectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.Devices.SelectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
if err := d.Devices.DeleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
devices, err = d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.Devices.DeleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.Devices.UpdateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.LoginTokenLifetime),
}
err = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.InsertLoginToken(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
return d.LoginTokens.DeleteLoginToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.LoginTokens.SelectLoginToken(ctx, token)
}

View file

@ -20,6 +20,7 @@ import (
"encoding/json" "encoding/json"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const accountDataSchema = ` const accountDataSchema = `
@ -56,27 +57,29 @@ type accountDataStatements struct {
selectAccountDataByTypeStmt *sql.Stmt selectAccountDataByTypeStmt *sql.Stmt
} }
func (s *accountDataStatements) prepare(db *sql.DB) (err error) { func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
s.db = db s := &accountDataStatements{
_, err = db.Exec(accountDataSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(accountDataSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertAccountDataStmt, insertAccountDataSQL}, {&s.insertAccountDataStmt, insertAccountDataSQL},
{&s.selectAccountDataStmt, selectAccountDataSQL}, {&s.selectAccountDataStmt, selectAccountDataSQL},
{&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL}, {&s.selectAccountDataByTypeStmt, selectAccountDataByTypeSQL},
}.Prepare(db) }.Prepare(db)
} }
func (s *accountDataStatements) insertAccountData( func (s *accountDataStatements) InsertAccountData(
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage, ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content) _, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
return err return err
} }
func (s *accountDataStatements) selectAccountData( func (s *accountDataStatements) SelectAccountData(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) ( ) (
/* global */ map[string]json.RawMessage, /* global */ map[string]json.RawMessage,
@ -113,7 +116,7 @@ func (s *accountDataStatements) selectAccountData(
return global, rooms, nil return global, rooms, nil
} }
func (s *accountDataStatements) selectAccountDataByType( func (s *accountDataStatements) SelectAccountDataByType(
ctx context.Context, localpart, roomID, dataType string, ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) { ) (data json.RawMessage, err error) {
var bytes []byte var bytes []byte

View file

@ -24,6 +24,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -77,15 +78,16 @@ type accountsStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *accountsStatements) execSchema(db *sql.DB) error { func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.AccountsTable, error) {
_, err := db.Exec(accountsSchema) s := &accountsStatements{
return err db: db,
serverName: serverName,
} }
_, err := db.Exec(accountsSchema)
func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { if err != nil {
s.db = db return nil, err
s.serverName = server }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertAccountStmt, insertAccountSQL}, {&s.insertAccountStmt, insertAccountSQL},
{&s.updatePasswordStmt, updatePasswordSQL}, {&s.updatePasswordStmt, updatePasswordSQL},
{&s.deactivateAccountStmt, deactivateAccountSQL}, {&s.deactivateAccountStmt, deactivateAccountSQL},
@ -98,7 +100,7 @@ func (s *accountsStatements) prepare(db *sql.DB, server gomatrixserverlib.Server
// insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing, // insertAccount creates a new account. 'hash' should be the password hash for this account. If it is missing,
// this account will be passwordless. Returns an error if this account already exists. Returns the account // this account will be passwordless. Returns an error if this account already exists. Returns the account
// on success. // on success.
func (s *accountsStatements) insertAccount( func (s *accountsStatements) InsertAccount(
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType, ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
) (*api.Account, error) { ) (*api.Account, error) {
createdTimeMS := time.Now().UnixNano() / 1000000 createdTimeMS := time.Now().UnixNano() / 1000000
@ -122,28 +124,28 @@ func (s *accountsStatements) insertAccount(
}, nil }, nil
} }
func (s *accountsStatements) updatePassword( func (s *accountsStatements) UpdatePassword(
ctx context.Context, localpart, passwordHash string, ctx context.Context, localpart, passwordHash string,
) (err error) { ) (err error) {
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart) _, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
return return
} }
func (s *accountsStatements) deactivateAccount( func (s *accountsStatements) DeactivateAccount(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (err error) { ) (err error) {
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart) _, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
return return
} }
func (s *accountsStatements) selectPasswordHash( func (s *accountsStatements) SelectPasswordHash(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (hash string, err error) { ) (hash string, err error) {
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash) err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
return return
} }
func (s *accountsStatements) selectAccountByLocalpart( func (s *accountsStatements) SelectAccountByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*api.Account, error) { ) (*api.Account, error) {
var appserviceIDPtr sql.NullString var appserviceIDPtr sql.NullString
@ -167,7 +169,7 @@ func (s *accountsStatements) selectAccountByLocalpart(
return &acc, nil return &acc, nil
} }
func (s *accountsStatements) selectNewNumericLocalpart( func (s *accountsStatements) SelectNewNumericLocalpart(
ctx context.Context, txn *sql.Tx, ctx context.Context, txn *sql.Tx,
) (id int64, err error) { ) (id int64, err error) {
stmt := s.selectNewNumericLocalpartStmt stmt := s.selectNewNumericLocalpartStmt

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
type devicesStatements struct { type devicesStatements struct {
db *sql.DB db *sql.DB
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt selectDeviceByTokenStmt *sql.Stmt
@ -98,55 +98,33 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *devicesStatements) execSchema(db *sql.DB) error { func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
db: db,
serverName: serverName,
}
_, err := db.Exec(devicesSchema) _, err := db.Exec(devicesSchema)
return err if err != nil {
return nil, err
} }
return s, sqlutil.StatementList{
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) { {&s.insertDeviceStmt, insertDeviceSQL},
s.db = db {&s.selectDevicesCountStmt, selectDevicesCountSQL},
s.writer = writer {&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
if err = s.execSchema(db); err != nil { {&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
return {&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
} {&s.updateDeviceNameStmt, updateDeviceNameSQL},
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil { {&s.deleteDeviceStmt, deleteDeviceSQL},
return {&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
} {&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil { {&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
return }.Prepare(db)
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
} }
// insertDevice creates a new device. Returns an error if any device with the same access token already exists. // insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID. // Returns an error if the user already has a device with the given device ID.
// Returns the device on success. // Returns the device on success.
func (s *devicesStatements) insertDevice( func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string, displayName *string, ipAddr, userAgent string,
) (*api.Device, error) { ) (*api.Device, error) {
@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice(
}, nil }, nil
} }
func (s *devicesStatements) deleteDevice( func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string, ctx context.Context, txn *sql.Tx, id, localpart string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt) stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
return err return err
} }
func (s *devicesStatements) deleteDevices( func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string, ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error { ) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1) orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
return err return err
} }
func (s *devicesStatements) deleteDevicesByLocalpart( func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt) stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err return err
} }
func (s *devicesStatements) updateDeviceName( func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string, ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error { ) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
return err return err
} }
func (s *devicesStatements) selectDeviceByToken( func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string, ctx context.Context, accessToken string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user // selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID // localpart and deviceID
func (s *devicesStatements) selectDeviceByID( func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string, ctx context.Context, localpart, deviceID string,
) (*api.Device, error) { ) (*api.Device, error) {
var dev api.Device var dev api.Device
@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err return &dev, err
} }
func (s *devicesStatements) selectDevicesByLocalpart( func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string, ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) { ) ([]api.Device, error) {
devices := []api.Device{} devices := []api.Device{}
@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil return devices, nil
} }
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) { func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1) sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs)) iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs { for i := range deviceIDs {
@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err() return devices, rows.Err()
} }
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error { func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000 lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt) stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID) _, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupTableSchema = ` const keyBackupTableSchema = `
@ -72,12 +73,13 @@ type keyBackupStatements struct {
selectKeysByRoomIDAndSessionIDStmt *sql.Stmt selectKeysByRoomIDAndSessionIDStmt *sql.Stmt
} }
func (s *keyBackupStatements) prepare(db *sql.DB) (err error) { func NewSQLiteKeyBackupTable(db *sql.DB) (tables.KeyBackupTable, error) {
_, err = db.Exec(keyBackupTableSchema) s := &keyBackupStatements{}
_, err := db.Exec(keyBackupTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertBackupKeyStmt, insertBackupKeySQL}, {&s.insertBackupKeyStmt, insertBackupKeySQL},
{&s.updateBackupKeyStmt, updateBackupKeySQL}, {&s.updateBackupKeyStmt, updateBackupKeySQL},
{&s.countKeysStmt, countKeysSQL}, {&s.countKeysStmt, countKeysSQL},
@ -87,14 +89,14 @@ func (s *keyBackupStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s keyBackupStatements) countKeys( func (s keyBackupStatements) CountKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (count int64, err error) { ) (count int64, err error) {
err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count) err = txn.Stmt(s.countKeysStmt).QueryRowContext(ctx, userID, version).Scan(&count)
return return
} }
func (s *keyBackupStatements) insertBackupKey( func (s *keyBackupStatements) InsertBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.insertBackupKeyStmt).ExecContext(
@ -103,7 +105,7 @@ func (s *keyBackupStatements) insertBackupKey(
return return
} }
func (s *keyBackupStatements) updateBackupKey( func (s *keyBackupStatements) UpdateBackupKey(
ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession, ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession,
) (err error) { ) (err error) {
_, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext( _, err = txn.Stmt(s.updateBackupKeyStmt).ExecContext(
@ -112,7 +114,7 @@ func (s *keyBackupStatements) updateBackupKey(
return return
} }
func (s *keyBackupStatements) selectKeys( func (s *keyBackupStatements) SelectKeys(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version) rows, err := txn.Stmt(s.selectKeysStmt).QueryContext(ctx, userID, version)
@ -122,7 +124,7 @@ func (s *keyBackupStatements) selectKeys(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomID( func (s *keyBackupStatements) SelectKeysByRoomID(
ctx context.Context, txn *sql.Tx, userID, version, roomID string, ctx context.Context, txn *sql.Tx, userID, version, roomID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID) rows, err := txn.Stmt(s.selectKeysByRoomIDStmt).QueryContext(ctx, userID, version, roomID)
@ -132,7 +134,7 @@ func (s *keyBackupStatements) selectKeysByRoomID(
return unpackKeys(ctx, rows) return unpackKeys(ctx, rows)
} }
func (s *keyBackupStatements) selectKeysByRoomIDAndSessionID( func (s *keyBackupStatements) SelectKeysByRoomIDAndSessionID(
ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string, ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string,
) (map[string]map[string]api.KeyBackupSession, error) { ) (map[string]map[string]api.KeyBackupSession, error) {
rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID) rows, err := txn.Stmt(s.selectKeysByRoomIDAndSessionIDStmt).QueryContext(ctx, userID, version, roomID, sessionID)

View file

@ -22,6 +22,7 @@ import (
"strconv" "strconv"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const keyBackupVersionTableSchema = ` const keyBackupVersionTableSchema = `
@ -67,12 +68,13 @@ type keyBackupVersionStatements struct {
updateKeyBackupETagStmt *sql.Stmt updateKeyBackupETagStmt *sql.Stmt
} }
func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) { func NewSQLiteKeyBackupVersionTable(db *sql.DB) (tables.KeyBackupVersionTable, error) {
_, err = db.Exec(keyBackupVersionTableSchema) s := &keyBackupVersionStatements{}
_, err := db.Exec(keyBackupVersionTableSchema)
if err != nil { if err != nil {
return return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertKeyBackupStmt, insertKeyBackupSQL}, {&s.insertKeyBackupStmt, insertKeyBackupSQL},
{&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL}, {&s.updateKeyBackupAuthDataStmt, updateKeyBackupAuthDataSQL},
{&s.deleteKeyBackupStmt, deleteKeyBackupSQL}, {&s.deleteKeyBackupStmt, deleteKeyBackupSQL},
@ -82,7 +84,7 @@ func (s *keyBackupVersionStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *keyBackupVersionStatements) insertKeyBackup( func (s *keyBackupVersionStatements) InsertKeyBackup(
ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string, ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string,
) (version string, err error) { ) (version string, err error) {
var versionInt int64 var versionInt int64
@ -90,7 +92,7 @@ func (s *keyBackupVersionStatements) insertKeyBackup(
return strconv.FormatInt(versionInt, 10), err return strconv.FormatInt(versionInt, 10), err
} }
func (s *keyBackupVersionStatements) updateKeyBackupAuthData( func (s *keyBackupVersionStatements) UpdateKeyBackupAuthData(
ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage, ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -101,7 +103,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupAuthData(
return err return err
} }
func (s *keyBackupVersionStatements) updateKeyBackupETag( func (s *keyBackupVersionStatements) UpdateKeyBackupETag(
ctx context.Context, txn *sql.Tx, userID, version, etag string, ctx context.Context, txn *sql.Tx, userID, version, etag string,
) error { ) error {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -112,7 +114,7 @@ func (s *keyBackupVersionStatements) updateKeyBackupETag(
return err return err
} }
func (s *keyBackupVersionStatements) deleteKeyBackup( func (s *keyBackupVersionStatements) DeleteKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (bool, error) { ) (bool, error) {
versionInt, err := strconv.ParseInt(version, 10, 64) versionInt, err := strconv.ParseInt(version, 10, 64)
@ -130,7 +132,7 @@ func (s *keyBackupVersionStatements) deleteKeyBackup(
return ra == 1, nil return ra == 1, nil
} }
func (s *keyBackupVersionStatements) selectKeyBackup( func (s *keyBackupVersionStatements) SelectKeyBackup(
ctx context.Context, txn *sql.Tx, userID, version string, ctx context.Context, txn *sql.Tx, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) { ) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
var versionInt int64 var versionInt int64

View file

@ -21,18 +21,17 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/util" "github.com/matrix-org/util"
) )
type loginTokenStatements struct { type loginTokenStatements struct {
insertStmt *sql.Stmt insertStmt *sql.Stmt
deleteStmt *sql.Stmt deleteStmt *sql.Stmt
selectByTokenStmt *sql.Stmt selectStmt *sql.Stmt
} }
// execSchema ensures tables and indices exist. const loginTokenSchema = `
func (s *loginTokenStatements) execSchema(db *sql.DB) error {
_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS login_tokens ( CREATE TABLE IF NOT EXISTS login_tokens (
-- The random value of the token issued to a user -- The random value of the token issued to a user
token TEXT NOT NULL PRIMARY KEY, token TEXT NOT NULL PRIMARY KEY,
@ -45,24 +44,32 @@ CREATE TABLE IF NOT EXISTS login_tokens (
-- This index allows efficient garbage collection of expired tokens. -- This index allows efficient garbage collection of expired tokens.
CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at);
`) `
return err
}
// prepare runs statement preparation. const insertLoginTokenSQL = "" +
func (s *loginTokenStatements) prepare(db *sql.DB) error { "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"
if err := s.execSchema(db); err != nil {
return err const deleteLoginTokenSQL = "" +
"DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"
const selectLoginTokenSQL = "" +
"SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"
func NewSQLiteLoginTokenTable(db *sql.DB) (tables.LoginTokenTable, error) {
s := &loginTokenStatements{}
_, err := db.Exec(loginTokenSchema)
if err != nil {
return nil, err
} }
return sqlutil.StatementList{ return s, sqlutil.StatementList{
{&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, {&s.insertStmt, insertLoginTokenSQL},
{&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= $2"}, {&s.deleteStmt, deleteLoginTokenSQL},
{&s.selectByTokenStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > $2"}, {&s.selectStmt, selectLoginTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insert adds an already generated token to the database. // insert adds an already generated token to the database.
func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { func (s *loginTokenStatements) InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error {
stmt := sqlutil.TxStmt(txn, s.insertStmt) stmt := sqlutil.TxStmt(txn, s.insertStmt)
_, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID)
return err return err
@ -72,7 +79,7 @@ func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata
// //
// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. // As a simple way to garbage-collect stale tokens, we also remove all expired tokens.
// The login_tokens_expiration_idx index should make that efficient. // The login_tokens_expiration_idx index should make that efficient.
func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { func (s *loginTokenStatements) DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error {
stmt := sqlutil.TxStmt(txn, s.deleteStmt) stmt := sqlutil.TxStmt(txn, s.deleteStmt)
res, err := stmt.ExecContext(ctx, token, time.Now().UTC()) res, err := stmt.ExecContext(ctx, token, time.Now().UTC())
if err != nil { if err != nil {
@ -85,9 +92,9 @@ func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, t
} }
// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. // selectByToken returns the data associated with the given token. May return sql.ErrNoRows.
func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { func (s *loginTokenStatements) SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
var data api.LoginTokenData var data api.LoginTokenData
err := s.selectByTokenStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID) err := s.selectStmt.QueryRowContext(ctx, token, time.Now().UTC()).Scan(&data.UserID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,6 +6,7 @@ import (
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@ -22,35 +23,37 @@ CREATE TABLE IF NOT EXISTS open_id_tokens (
); );
` `
const insertTokenSQL = "" + const insertOpenIDTokenSQL = "" +
"INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)" "INSERT INTO open_id_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
const selectTokenSQL = "" + const selectOpenIDTokenSQL = "" +
"SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1" "SELECT localpart, token_expires_at_ms FROM open_id_tokens WHERE token = $1"
type tokenStatements struct { type openIDTokenStatements struct {
db *sql.DB db *sql.DB
insertTokenStmt *sql.Stmt insertTokenStmt *sql.Stmt
selectTokenStmt *sql.Stmt selectTokenStmt *sql.Stmt
serverName gomatrixserverlib.ServerName serverName gomatrixserverlib.ServerName
} }
func (s *tokenStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) { func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.OpenIDTable, error) {
s.db = db s := &openIDTokenStatements{
_, err = db.Exec(openIDTokenSchema) db: db,
if err != nil { serverName: serverName,
return err
} }
s.serverName = server _, err := db.Exec(openIDTokenSchema)
return sqlutil.StatementList{ if err != nil {
{&s.insertTokenStmt, insertTokenSQL}, return nil, err
{&s.selectTokenStmt, selectTokenSQL}, }
return s, sqlutil.StatementList{
{&s.insertTokenStmt, insertOpenIDTokenSQL},
{&s.selectTokenStmt, selectOpenIDTokenSQL},
}.Prepare(db) }.Prepare(db)
} }
// insertToken inserts a new OpenID Connect token to the DB. // insertToken inserts a new OpenID Connect token to the DB.
// Returns new token, otherwise returns error if the token already exists. // Returns new token, otherwise returns error if the token already exists.
func (s *tokenStatements) insertToken( func (s *openIDTokenStatements) InsertOpenIDToken(
ctx context.Context, ctx context.Context,
txn *sql.Tx, txn *sql.Tx,
token, localpart string, token, localpart string,
@ -63,7 +66,7 @@ func (s *tokenStatements) insertToken(
// selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB // selectOpenIDTokenAtrributes gets the attributes associated with an OpenID token from the DB
// Returns the existing token's attributes, or err if no token is found // Returns the existing token's attributes, or err if no token is found
func (s *tokenStatements) selectOpenIDTokenAtrributes( func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
ctx context.Context, ctx context.Context,
token string, token string,
) (*api.OpenIDTokenAttributes, error) { ) (*api.OpenIDTokenAttributes, error) {

View file

@ -22,6 +22,7 @@ import (
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
) )
const profilesSchema = ` const profilesSchema = `
@ -60,13 +61,15 @@ type profilesStatements struct {
selectProfilesBySearchStmt *sql.Stmt selectProfilesBySearchStmt *sql.Stmt
} }
func (s *profilesStatements) prepare(db *sql.DB) (err error) { func NewSQLiteProfilesTable(db *sql.DB) (tables.ProfileTable, error) {
s.db = db s := &profilesStatements{
_, err = db.Exec(profilesSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(profilesSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.insertProfileStmt, insertProfileSQL}, {&s.insertProfileStmt, insertProfileSQL},
{&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL}, {&s.selectProfileByLocalpartStmt, selectProfileByLocalpartSQL},
{&s.setAvatarURLStmt, setAvatarURLSQL}, {&s.setAvatarURLStmt, setAvatarURLSQL},
@ -75,14 +78,14 @@ func (s *profilesStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *profilesStatements) insertProfile( func (s *profilesStatements) InsertProfile(
ctx context.Context, txn *sql.Tx, localpart string, ctx context.Context, txn *sql.Tx, localpart string,
) error { ) error {
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "") _, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
return err return err
} }
func (s *profilesStatements) selectProfileByLocalpart( func (s *profilesStatements) SelectProfileByLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (*authtypes.Profile, error) { ) (*authtypes.Profile, error) {
var profile authtypes.Profile var profile authtypes.Profile
@ -95,7 +98,7 @@ func (s *profilesStatements) selectProfileByLocalpart(
return &profile, nil return &profile, nil
} }
func (s *profilesStatements) setAvatarURL( func (s *profilesStatements) SetAvatarURL(
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string, ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt) stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
@ -103,7 +106,7 @@ func (s *profilesStatements) setAvatarURL(
return return
} }
func (s *profilesStatements) setDisplayName( func (s *profilesStatements) SetDisplayName(
ctx context.Context, txn *sql.Tx, localpart string, displayName string, ctx context.Context, txn *sql.Tx, localpart string, displayName string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt) stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
@ -111,7 +114,7 @@ func (s *profilesStatements) setDisplayName(
return return
} }
func (s *profilesStatements) selectProfilesBySearch( func (s *profilesStatements) SelectProfilesBySearch(
ctx context.Context, searchString string, limit int, ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) { ) ([]authtypes.Profile, error) {
var profiles []authtypes.Profile var profiles []authtypes.Profile

View file

@ -15,80 +15,34 @@
package sqlite3 package sqlite3
import ( import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"encoding/json"
"errors"
"fmt" "fmt"
"strconv"
"sync"
"time" "time"
"github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/gomatrixserverlib"
"golang.org/x/crypto/bcrypt"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/setup/config"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/shared"
"github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas" "github.com/matrix-org/dendrite/userapi/storage/sqlite3/deltas"
)
// Database represents an account database // Import the postgres database driver.
type Database struct { _ "github.com/lib/pq"
db *sql.DB
writer sqlutil.Writer
sqlutil.PartitionOffsetStatements
accounts accountsStatements
profiles profilesStatements
accountDatas accountDataStatements
threepids threepidStatements
openIDTokens tokenStatements
keyBackupVersions keyBackupVersionStatements
keyBackups keyBackupStatements
devices devicesStatements
loginTokens loginTokenStatements
loginTokenLifetime time.Duration
serverName gomatrixserverlib.ServerName
bcryptCost int
openIDTokenLifetimeMS int64
accountsMu sync.Mutex
profilesMu sync.Mutex
accountDatasMu sync.Mutex
threepidsMu sync.Mutex
}
const (
// The length of generated device IDs
deviceIDByteLength = 6
loginTokenByteLength = 32
) )
// NewDatabase creates a new accounts and profiles database // NewDatabase creates a new accounts and profiles database
func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*Database, error) { func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, bcryptCost int, openIDTokenLifetimeMS int64, loginTokenLifetime time.Duration) (*shared.Database, error) {
db, err := sqlutil.Open(dbProperties) db, err := sqlutil.Open(dbProperties)
if err != nil { if err != nil {
return nil, err return nil, err
} }
d := &Database{
serverName: serverName,
db: db,
writer: sqlutil.NewExclusiveWriter(),
loginTokenLifetime: loginTokenLifetime,
bcryptCost: bcryptCost,
openIDTokenLifetimeMS: openIDTokenLifetimeMS,
}
// Create tables before executing migrations so we don't fail if the table is missing, m := sqlutil.NewMigrations()
// and THEN prepare statements so we don't fail due to referencing new columns if _, err = db.Exec(accountsSchema); err != nil {
if err = d.accounts.execSchema(db); err != nil { // do this so that the migration can and we don't fail on
// preparing statements for columns that don't exist yet
return nil, err return nil, err
} }
m := sqlutil.NewMigrations()
deltas.LoadIsActive(m) deltas.LoadIsActive(m)
//deltas.LoadLastSeenTSIP(m) //deltas.LoadLastSeenTSIP(m)
deltas.LoadAddAccountType(m) deltas.LoadAddAccountType(m)
@ -96,666 +50,57 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver
return nil, err return nil, err
} }
partitions := sqlutil.PartitionOffsetStatements{} accountDataTable, err := NewSQLiteAccountDataTable(db)
if err = partitions.Prepare(db, d.writer, "account"); err != nil {
return nil, err
}
if err = d.accounts.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.profiles.prepare(db); err != nil {
return nil, err
}
if err = d.accountDatas.prepare(db); err != nil {
return nil, err
}
if err = d.threepids.prepare(db); err != nil {
return nil, err
}
if err = d.openIDTokens.prepare(db, serverName); err != nil {
return nil, err
}
if err = d.keyBackupVersions.prepare(db); err != nil {
return nil, err
}
if err = d.keyBackups.prepare(db); err != nil {
return nil, err
}
if err = d.devices.prepare(db, d.writer, serverName); err != nil {
return nil, err
}
if err = d.loginTokens.prepare(db); err != nil {
return nil, err
}
return d, nil
}
// GetAccountByPassword returns the account associated with the given localpart and password.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByPassword(
ctx context.Context, localpart, plaintextPassword string,
) (*api.Account, error) {
hash, err := d.accounts.selectPasswordHash(ctx, localpart)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
} }
if err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(plaintextPassword)); err != nil { accountsTable, err := NewSQLiteAccountsTable(db, serverName)
return nil, err
}
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// GetProfileByLocalpart returns the profile associated with the given localpart.
// Returns sql.ErrNoRows if no profile exists which matches the given localpart.
func (d *Database) GetProfileByLocalpart(
ctx context.Context, localpart string,
) (*authtypes.Profile, error) {
return d.profiles.selectProfileByLocalpart(ctx, localpart)
}
// SetAvatarURL updates the avatar URL of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetAvatarURL(
ctx context.Context, localpart string, avatarURL string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setAvatarURL(ctx, txn, localpart, avatarURL)
})
}
// SetDisplayName updates the display name of the profile associated with the given
// localpart. Returns an error if something went wrong with the SQL query
func (d *Database) SetDisplayName(
ctx context.Context, localpart string, displayName string,
) error {
d.profilesMu.Lock()
defer d.profilesMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.profiles.setDisplayName(ctx, txn, localpart, displayName)
})
}
// SetPassword sets the account password to the given hash.
func (d *Database) SetPassword(
ctx context.Context, localpart, plaintextPassword string,
) error {
hash, err := d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
} }
return d.writer.Do(nil, nil, func(txn *sql.Tx) error { devicesTable, err := NewSQLiteDevicesTable(db, serverName)
return d.accounts.updatePassword(ctx, localpart, hash)
})
}
// CreateAccount makes a new account with the given login name and password, and creates an empty profile
// for this account. If no password is supplied, the account will be a passwordless account. If the
// account already exists, it will return nil, ErrUserExists.
func (d *Database) CreateAccount(
ctx context.Context, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (acc *api.Account, err error) {
// Create one account at a time else we can get 'database is locked'.
d.profilesMu.Lock()
d.accountDatasMu.Lock()
d.accountsMu.Lock()
defer d.profilesMu.Unlock()
defer d.accountDatasMu.Unlock()
defer d.accountsMu.Unlock()
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
// For guest accounts, we create a new numeric local part
if accountType == api.AccountTypeGuest {
var numLocalpart int64
numLocalpart, err = d.accounts.selectNewNumericLocalpart(ctx, txn)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
} }
localpart = strconv.FormatInt(numLocalpart, 10) keyBackupTable, err := NewSQLiteKeyBackupTable(db)
plaintextPassword = ""
appserviceID = ""
}
acc, err = d.createAccount(ctx, txn, localpart, plaintextPassword, appserviceID, accountType)
return err
})
return
}
// WARNING! This function assumes that the relevant mutexes have already
// been taken out by the caller (e.g. CreateAccount or CreateGuestAccount).
func (d *Database) createAccount(
ctx context.Context, txn *sql.Tx, localpart, plaintextPassword, appserviceID string, accountType api.AccountType,
) (*api.Account, error) {
var err error
var account *api.Account
// Generate a password hash if this is not a password-less user
hash := ""
if plaintextPassword != "" {
hash, err = d.hashPassword(plaintextPassword)
if err != nil { if err != nil {
return nil, err return nil, fmt.Errorf("NewSQLiteKeyBackupTable: %w", err)
} }
} keyBackupVersionTable, err := NewSQLiteKeyBackupVersionTable(db)
if account, err = d.accounts.insertAccount(ctx, txn, localpart, hash, appserviceID, accountType); err != nil {
return nil, sqlutil.ErrUserExists
}
if err = d.profiles.insertProfile(ctx, txn, localpart); err != nil {
return nil, err
}
if err = d.accountDatas.insertAccountData(ctx, txn, localpart, "", "m.push_rules", json.RawMessage(`{
"global": {
"content": [],
"override": [],
"room": [],
"sender": [],
"underride": []
}
}`)); err != nil {
return nil, err
}
return account, nil
}
// SaveAccountData saves new account data for a given user and a given room.
// If the account data is not specific to a room, the room ID should be an empty string
// If an account data already exists for a given set (user, room, data type), it will
// update the corresponding row with the new content
// Returns a SQL error if there was an issue with the insertion/update
func (d *Database) SaveAccountData(
ctx context.Context, localpart, roomID, dataType string, content json.RawMessage,
) error {
d.accountDatasMu.Lock()
defer d.accountDatasMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.accountDatas.insertAccountData(ctx, txn, localpart, roomID, dataType, content)
})
}
// GetAccountData returns account data related to a given localpart
// If no account data could be found, returns an empty arrays
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountData(ctx context.Context, localpart string) (
global map[string]json.RawMessage,
rooms map[string]map[string]json.RawMessage,
err error,
) {
return d.accountDatas.selectAccountData(ctx, localpart)
}
// GetAccountDataByType returns account data matching a given
// localpart, room ID and type.
// If no account data could be found, returns nil
// Returns an error if there was an issue with the retrieval
func (d *Database) GetAccountDataByType(
ctx context.Context, localpart, roomID, dataType string,
) (data json.RawMessage, err error) {
return d.accountDatas.selectAccountDataByType(
ctx, localpart, roomID, dataType,
)
}
// GetNewNumericLocalpart generates and returns a new unused numeric localpart
func (d *Database) GetNewNumericLocalpart(
ctx context.Context,
) (int64, error) {
return d.accounts.selectNewNumericLocalpart(ctx, nil)
}
func (d *Database) hashPassword(plaintext string) (hash string, err error) {
hashBytes, err := bcrypt.GenerateFromPassword([]byte(plaintext), d.bcryptCost)
return string(hashBytes), err
}
// Err3PIDInUse is the error returned when trying to save an association involving
// a third-party identifier which is already associated to a local user.
var Err3PIDInUse = errors.New("this third-party identifier is already in use")
// SaveThreePIDAssociation saves the association between a third party identifier
// and a local Matrix user (identified by the user's ID's local part).
// If the third-party identifier is already part of an association, returns Err3PIDInUse.
// Returns an error if there was a problem talking to the database.
func (d *Database) SaveThreePIDAssociation(
ctx context.Context, threepid, localpart, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
user, err := d.threepids.selectLocalpartForThreePID(
ctx, txn, threepid, medium,
)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteKeyBackupVersionTable: %w", err)
} }
loginTokenTable, err := NewSQLiteLoginTokenTable(db)
if len(user) > 0 {
return Err3PIDInUse
}
return d.threepids.insertThreePID(ctx, txn, threepid, medium, localpart)
})
}
// RemoveThreePIDAssociation removes the association involving a given third-party
// identifier.
// If no association exists involving this third-party identifier, returns nothing.
// If there was a problem talking to the database, returns an error.
func (d *Database) RemoveThreePIDAssociation(
ctx context.Context, threepid string, medium string,
) (err error) {
d.threepidsMu.Lock()
defer d.threepidsMu.Unlock()
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.threepids.deleteThreePID(ctx, txn, threepid, medium)
})
}
// GetLocalpartForThreePID looks up the localpart associated with a given third-party
// identifier.
// If no association involves the given third-party idenfitier, returns an empty
// string.
// Returns an error if there was a problem talking to the database.
func (d *Database) GetLocalpartForThreePID(
ctx context.Context, threepid string, medium string,
) (localpart string, err error) {
return d.threepids.selectLocalpartForThreePID(ctx, nil, threepid, medium)
}
// GetThreePIDsForLocalpart looks up the third-party identifiers associated with
// a given local user.
// If no association is known for this user, returns an empty slice.
// Returns an error if there was an issue talking to the database.
func (d *Database) GetThreePIDsForLocalpart(
ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) {
return d.threepids.selectThreePIDsForLocalpart(ctx, localpart)
}
// CheckAccountAvailability checks if the username/localpart is already present
// in the database.
// If the DB returns sql.ErrNoRows the Localpart isn't taken.
func (d *Database) CheckAccountAvailability(ctx context.Context, localpart string) (bool, error) {
_, err := d.accounts.selectAccountByLocalpart(ctx, localpart)
if err == sql.ErrNoRows {
return true, nil
}
return false, err
}
// GetAccountByLocalpart returns the account associated with the given localpart.
// This function assumes the request is authenticated or the account data is used only internally.
// Returns sql.ErrNoRows if no account exists which matches the given localpart.
func (d *Database) GetAccountByLocalpart(ctx context.Context, localpart string,
) (*api.Account, error) {
return d.accounts.selectAccountByLocalpart(ctx, localpart)
}
// SearchProfiles returns all profiles where the provided localpart or display name
// match any part of the profiles in the database.
func (d *Database) SearchProfiles(ctx context.Context, searchString string, limit int,
) ([]authtypes.Profile, error) {
return d.profiles.selectProfilesBySearch(ctx, searchString, limit)
}
// DeactivateAccount deactivates the user's account, removing all ability for the user to login again.
func (d *Database) DeactivateAccount(ctx context.Context, localpart string) (err error) {
return d.writer.Do(nil, nil, func(txn *sql.Tx) error {
return d.accounts.deactivateAccount(ctx, localpart)
})
}
// CreateOpenIDToken persists a new token that was issued for OpenID Connect
func (d *Database) CreateOpenIDToken(
ctx context.Context,
token, localpart string,
) (int64, error) {
expiresAtMS := time.Now().UnixNano()/int64(time.Millisecond) + d.openIDTokenLifetimeMS
err := d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.openIDTokens.insertToken(ctx, txn, token, localpart, expiresAtMS)
})
return expiresAtMS, err
}
// GetOpenIDTokenAttributes gets the attributes of issued an OIDC auth token
func (d *Database) GetOpenIDTokenAttributes(
ctx context.Context,
token string,
) (*api.OpenIDTokenAttributes, error) {
return d.openIDTokens.selectOpenIDTokenAtrributes(ctx, token)
}
func (d *Database) CreateKeyBackup(
ctx context.Context, userID, algorithm string, authData json.RawMessage,
) (version string, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
version, err = d.keyBackupVersions.insertKeyBackup(ctx, txn, userID, algorithm, authData, "")
return err
})
return
}
func (d *Database) UpdateKeyBackupAuthData(
ctx context.Context, userID, version string, authData json.RawMessage,
) (err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.keyBackupVersions.updateKeyBackupAuthData(ctx, txn, userID, version, authData)
})
return
}
func (d *Database) DeleteKeyBackup(
ctx context.Context, userID, version string,
) (exists bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
exists, err = d.keyBackupVersions.deleteKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetKeyBackup(
ctx context.Context, userID, version string,
) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
versionResult, algorithm, authData, etag, deleted, err = d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) GetBackupKeys(
ctx context.Context, version, userID, filterRoomID, filterSessionID string,
) (result map[string]map[string]api.KeyBackupSession, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if filterSessionID != "" {
result, err = d.keyBackups.selectKeysByRoomIDAndSessionID(ctx, txn, userID, version, filterRoomID, filterSessionID)
return err
}
if filterRoomID != "" {
result, err = d.keyBackups.selectKeysByRoomID(ctx, txn, userID, version, filterRoomID)
return err
}
result, err = d.keyBackups.selectKeys(ctx, txn, userID, version)
return err
})
return
}
func (d *Database) CountBackupKeys(
ctx context.Context, version, userID string,
) (count int64, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
count, err = d.keyBackups.countKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteLoginTokenTable: %w", err)
} }
return nil openIDTable, err := NewSQLiteOpenIDTable(db, serverName)
})
return
}
// nolint:nakedret
func (d *Database) UpsertBackupKeys(
ctx context.Context, version, userID string, uploads []api.InternalKeyBackupSession,
) (count int64, etag string, err error) {
// wrap the following logic in a txn to ensure we atomically upload keys
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
_, _, _, oldETag, deleted, err := d.keyBackupVersions.selectKeyBackup(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteOpenIDTable: %w", err)
} }
if deleted { profilesTable, err := NewSQLiteProfilesTable(db)
return fmt.Errorf("backup was deleted")
}
// pull out all keys for this (user_id, version)
existingKeys, err := d.keyBackups.selectKeys(ctx, txn, userID, version)
if err != nil { if err != nil {
return err return nil, fmt.Errorf("NewSQLiteProfilesTable: %w", err)
} }
threePIDTable, err := NewSQLiteThreePIDTable(db)
changed := false
// loop over all the new keys (which should be smaller than the set of backed up keys)
for _, newKey := range uploads {
// if we have a matching (room_id, session_id), we may need to update the key if it meets some rules, check them.
existingRoom := existingKeys[newKey.RoomID]
if existingRoom != nil {
existingSession, ok := existingRoom[newKey.SessionID]
if ok {
if existingSession.ShouldReplaceRoomKey(&newKey.KeyBackupSession) {
err = d.keyBackups.updateBackupKey(ctx, txn, userID, version, newKey)
changed = true
if err != nil { if err != nil {
return fmt.Errorf("d.keyBackups.updateBackupKey: %w", err) return nil, fmt.Errorf("NewSQLiteThreePIDTable: %w", err)
} }
} return &shared.Database{
// if we shouldn't replace the key we do nothing with it AccountDatas: accountDataTable,
continue Accounts: accountsTable,
} Devices: devicesTable,
} KeyBackups: keyBackupTable,
// if we're here, either the room or session are new, either way, we insert KeyBackupVersions: keyBackupVersionTable,
err = d.keyBackups.insertBackupKey(ctx, txn, userID, version, newKey) LoginTokens: loginTokenTable,
changed = true OpenIDTokens: openIDTable,
if err != nil { Profiles: profilesTable,
return fmt.Errorf("d.keyBackups.insertBackupKey: %w", err) ThreePIDs: threePIDTable,
} ServerName: serverName,
} DB: db,
Writer: sqlutil.NewExclusiveWriter(),
count, err = d.keyBackups.countKeys(ctx, txn, userID, version) LoginTokenLifetime: loginTokenLifetime,
if err != nil { BcryptCost: bcryptCost,
return err OpenIDTokenLifetimeMS: openIDTokenLifetimeMS,
} }, nil
if changed {
// update the etag
var newETag string
if oldETag == "" {
newETag = "1"
} else {
oldETagInt, err := strconv.ParseInt(oldETag, 10, 64)
if err != nil {
return fmt.Errorf("failed to parse old etag: %s", err)
}
newETag = strconv.FormatInt(oldETagInt+1, 10)
}
etag = newETag
return d.keyBackupVersions.updateKeyBackupETag(ctx, txn, userID, version, newETag)
} else {
etag = oldETag
}
return nil
})
return
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, nil, localpart, "")
}
func (d *Database) GetDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
return d.devices.selectDevicesByID(ctx, deviceIDs)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName, ipAddr, userAgent)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart, exceptDeviceID string,
) (devices []api.Device, err error) {
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
devices, err = d.devices.selectDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID)
if err != nil {
return err
}
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart, exceptDeviceID); err != sql.ErrNoRows {
return err
}
return nil
})
return
}
// UpdateDeviceLastSeen updates a the last seen timestamp and the ip address
func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID, ipAddr string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr)
})
}
// CreateLoginToken generates a token, stores and returns it. The lifetime is
// determined by the loginTokenLifetime given to the Database constructor.
func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) {
tok, err := generateLoginToken()
if err != nil {
return nil, err
}
meta := &api.LoginTokenMetadata{
Token: tok,
Expiration: time.Now().Add(d.loginTokenLifetime),
}
err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.insert(ctx, txn, meta, data)
})
if err != nil {
return nil, err
}
return meta, nil
}
func generateLoginToken() (string, error) {
b := make([]byte, loginTokenByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(b), nil
}
// RemoveLoginToken removes the named token (and may clean up other expired tokens).
func (d *Database) RemoveLoginToken(ctx context.Context, token string) error {
return d.writer.Do(d.db, nil, func(txn *sql.Tx) error {
return d.loginTokens.deleteByToken(ctx, txn, token)
})
}
// GetLoginTokenDataByToken returns the data associated with the given token.
// May return sql.ErrNoRows.
func (d *Database) GetLoginTokenDataByToken(ctx context.Context, token string) (*api.LoginTokenData, error) {
return d.loginTokens.selectByToken(ctx, token)
} }

View file

@ -20,6 +20,7 @@ import (
"github.com/matrix-org/dendrite/internal" "github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/clientapi/auth/authtypes"
) )
@ -60,13 +61,15 @@ type threepidStatements struct {
deleteThreePIDStmt *sql.Stmt deleteThreePIDStmt *sql.Stmt
} }
func (s *threepidStatements) prepare(db *sql.DB) (err error) { func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
s.db = db s := &threepidStatements{
_, err = db.Exec(threepidSchema) db: db,
if err != nil {
return
} }
return sqlutil.StatementList{ _, err := db.Exec(threepidSchema)
if err != nil {
return nil, err
}
return s, sqlutil.StatementList{
{&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL}, {&s.selectLocalpartForThreePIDStmt, selectLocalpartForThreePIDSQL},
{&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL}, {&s.selectThreePIDsForLocalpartStmt, selectThreePIDsForLocalpartSQL},
{&s.insertThreePIDStmt, insertThreePIDSQL}, {&s.insertThreePIDStmt, insertThreePIDSQL},
@ -74,7 +77,7 @@ func (s *threepidStatements) prepare(db *sql.DB) (err error) {
}.Prepare(db) }.Prepare(db)
} }
func (s *threepidStatements) selectLocalpartForThreePID( func (s *threepidStatements) SelectLocalpartForThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string, ctx context.Context, txn *sql.Tx, threepid string, medium string,
) (localpart string, err error) { ) (localpart string, err error) {
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
@ -85,7 +88,7 @@ func (s *threepidStatements) selectLocalpartForThreePID(
return return
} }
func (s *threepidStatements) selectThreePIDsForLocalpart( func (s *threepidStatements) SelectThreePIDsForLocalpart(
ctx context.Context, localpart string, ctx context.Context, localpart string,
) (threepids []authtypes.ThreePID, err error) { ) (threepids []authtypes.ThreePID, err error) {
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart) rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
@ -109,7 +112,7 @@ func (s *threepidStatements) selectThreePIDsForLocalpart(
return threepids, rows.Err() return threepids, rows.Err()
} }
func (s *threepidStatements) insertThreePID( func (s *threepidStatements) InsertThreePID(
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string, ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
) (err error) { ) (err error) {
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
@ -117,7 +120,7 @@ func (s *threepidStatements) insertThreePID(
return err return err
} }
func (s *threepidStatements) deleteThreePID( func (s *threepidStatements) DeleteThreePID(
ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) { ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error) {
stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt) stmt := sqlutil.TxStmt(txn, s.deleteThreePIDStmt)
_, err = stmt.ExecContext(ctx, threepid, medium) _, err = stmt.ExecContext(ctx, threepid, medium)

View file

@ -0,0 +1,95 @@
// Copyright 2022 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package tables
import (
"context"
"database/sql"
"encoding/json"
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
"github.com/matrix-org/dendrite/userapi/api"
)
type AccountDataTable interface {
InsertAccountData(ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage) error
SelectAccountData(ctx context.Context, localpart string) (map[string]json.RawMessage, map[string]map[string]json.RawMessage, error)
SelectAccountDataByType(ctx context.Context, localpart, roomID, dataType string) (data json.RawMessage, err error)
}
type AccountsTable interface {
InsertAccount(ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType) (*api.Account, error)
UpdatePassword(ctx context.Context, localpart, passwordHash string) (err error)
DeactivateAccount(ctx context.Context, localpart string) (err error)
SelectPasswordHash(ctx context.Context, localpart string) (hash string, err error)
SelectAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error)
SelectNewNumericLocalpart(ctx context.Context, txn *sql.Tx) (id int64, err error)
}
type DevicesTable interface {
InsertDevice(ctx context.Context, txn *sql.Tx, id, localpart, accessToken string, displayName *string, ipAddr, userAgent string) (*api.Device, error)
DeleteDevice(ctx context.Context, txn *sql.Tx, id, localpart string) error
DeleteDevices(ctx context.Context, txn *sql.Tx, localpart string, devices []string) error
DeleteDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) error
UpdateDeviceName(ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string) error
SelectDeviceByToken(ctx context.Context, accessToken string) (*api.Device, error)
SelectDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
SelectDevicesByLocalpart(ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string) ([]api.Device, error)
SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error)
UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error
}
type KeyBackupTable interface {
CountKeys(ctx context.Context, txn *sql.Tx, userID, version string) (count int64, err error)
InsertBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
UpdateBackupKey(ctx context.Context, txn *sql.Tx, userID, version string, key api.InternalKeyBackupSession) (err error)
SelectKeys(ctx context.Context, txn *sql.Tx, userID, version string) (map[string]map[string]api.KeyBackupSession, error)
SelectKeysByRoomID(ctx context.Context, txn *sql.Tx, userID, version, roomID string) (map[string]map[string]api.KeyBackupSession, error)
SelectKeysByRoomIDAndSessionID(ctx context.Context, txn *sql.Tx, userID, version, roomID, sessionID string) (map[string]map[string]api.KeyBackupSession, error)
}
type KeyBackupVersionTable interface {
InsertKeyBackup(ctx context.Context, txn *sql.Tx, userID, algorithm string, authData json.RawMessage, etag string) (version string, err error)
UpdateKeyBackupAuthData(ctx context.Context, txn *sql.Tx, userID, version string, authData json.RawMessage) error
UpdateKeyBackupETag(ctx context.Context, txn *sql.Tx, userID, version, etag string) error
DeleteKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (bool, error)
SelectKeyBackup(ctx context.Context, txn *sql.Tx, userID, version string) (versionResult, algorithm string, authData json.RawMessage, etag string, deleted bool, err error)
}
type LoginTokenTable interface {
InsertLoginToken(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error
DeleteLoginToken(ctx context.Context, txn *sql.Tx, token string) error
SelectLoginToken(ctx context.Context, token string) (*api.LoginTokenData, error)
}
type OpenIDTable interface {
InsertOpenIDToken(ctx context.Context, txn *sql.Tx, token, localpart string, expiresAtMS int64) (err error)
SelectOpenIDTokenAtrributes(ctx context.Context, token string) (*api.OpenIDTokenAttributes, error)
}
type ProfileTable interface {
InsertProfile(ctx context.Context, txn *sql.Tx, localpart string) error
SelectProfileByLocalpart(ctx context.Context, localpart string) (*authtypes.Profile, error)
SetAvatarURL(ctx context.Context, txn *sql.Tx, localpart string, avatarURL string) (err error)
SetDisplayName(ctx context.Context, txn *sql.Tx, localpart string, displayName string) (err error)
SelectProfilesBySearch(ctx context.Context, searchString string, limit int) ([]authtypes.Profile, error)
}
type ThreePIDTable interface {
SelectLocalpartForThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (localpart string, err error)
SelectThreePIDsForLocalpart(ctx context.Context, localpart string) (threepids []authtypes.ThreePID, err error)
InsertThreePID(ctx context.Context, txn *sql.Tx, threepid, medium, localpart string) (err error)
DeleteThreePID(ctx context.Context, txn *sql.Tx, threepid string, medium string) (err error)
}