mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-01 13:52:46 +00:00
Virtual hosting schema and logic changes (#2876)
Note that virtual users cannot federate correctly yet.
This commit is contained in:
parent
e177e0ae73
commit
529df30b56
62 changed files with 1250 additions and 732 deletions
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const accountDataSchema = `
|
||||
|
@ -28,27 +29,28 @@ const accountDataSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_account_datas (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The room ID for this data (empty string if not specific to a room)
|
||||
room_id TEXT,
|
||||
-- The account data type
|
||||
type TEXT NOT NULL,
|
||||
-- The account data content
|
||||
content TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(localpart, room_id, type)
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_account_datas_idx ON userapi_account_datas(localpart, server_name, room_id, type);
|
||||
`
|
||||
|
||||
const insertAccountDataSQL = `
|
||||
INSERT INTO userapi_account_datas(localpart, room_id, type, content) VALUES($1, $2, $3, $4)
|
||||
ON CONFLICT (localpart, room_id, type) DO UPDATE SET content = $4
|
||||
INSERT INTO userapi_account_datas(localpart, server_name, room_id, type, content) VALUES($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (localpart, server_name, room_id, type) DO UPDATE SET content = $5
|
||||
`
|
||||
|
||||
const selectAccountDataSQL = "" +
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1"
|
||||
"SELECT room_id, type, content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountDataByTypeSQL = "" +
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND room_id = $2 AND type = $3"
|
||||
"SELECT content FROM userapi_account_datas WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND type = $4"
|
||||
|
||||
type accountDataStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -73,20 +75,23 @@ func NewSQLiteAccountDataTable(db *sql.DB) (tables.AccountDataTable, error) {
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) InsertAccountData(
|
||||
ctx context.Context, txn *sql.Tx, localpart, roomID, dataType string, content json.RawMessage,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string, content json.RawMessage,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, roomID, dataType, content)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertAccountDataStmt).ExecContext(ctx, localpart, serverName, roomID, dataType, content)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountData(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (
|
||||
/* global */ map[string]json.RawMessage,
|
||||
/* rooms */ map[string]map[string]json.RawMessage,
|
||||
error,
|
||||
) {
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectAccountDataStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -117,11 +122,13 @@ func (s *accountDataStatements) SelectAccountData(
|
|||
}
|
||||
|
||||
func (s *accountDataStatements) SelectAccountDataByType(
|
||||
ctx context.Context, localpart, roomID, dataType string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
roomID, dataType string,
|
||||
) (data json.RawMessage, err error) {
|
||||
var bytes []byte
|
||||
stmt := s.selectAccountDataByTypeStmt
|
||||
if err = stmt.QueryRowContext(ctx, localpart, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err = stmt.QueryRowContext(ctx, localpart, serverName, roomID, dataType).Scan(&bytes); err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
@ -34,7 +34,8 @@ const accountsSchema = `
|
|||
-- Stores data about accounts.
|
||||
CREATE TABLE IF NOT EXISTS userapi_accounts (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When this account was first created, as a unix timestamp (ms resolution).
|
||||
created_ts BIGINT NOT NULL,
|
||||
-- The password hash for this account. Can be NULL if this is a passwordless account.
|
||||
|
@ -48,25 +49,27 @@ CREATE TABLE IF NOT EXISTS userapi_accounts (
|
|||
-- TODO:
|
||||
-- upgraded_ts, devices, any email reset stuff?
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_accounts_idx ON userapi_accounts(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertAccountSQL = "" +
|
||||
"INSERT INTO userapi_accounts(localpart, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5)"
|
||||
"INSERT INTO userapi_accounts(localpart, server_name, created_ts, password_hash, appservice_id, account_type) VALUES ($1, $2, $3, $4, $5, $6)"
|
||||
|
||||
const updatePasswordSQL = "" +
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2"
|
||||
"UPDATE userapi_accounts SET password_hash = $1 WHERE localpart = $2 AND server_name = $3"
|
||||
|
||||
const deactivateAccountSQL = "" +
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1"
|
||||
"UPDATE userapi_accounts SET is_deactivated = 1 WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectAccountByLocalpartSQL = "" +
|
||||
"SELECT localpart, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, appservice_id, account_type FROM userapi_accounts WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const selectPasswordHashSQL = "" +
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND is_deactivated = 0"
|
||||
"SELECT password_hash FROM userapi_accounts WHERE localpart = $1 AND server_name = $2 AND is_deactivated = 0"
|
||||
|
||||
const selectNewNumericLocalpartSQL = "" +
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0"
|
||||
"SELECT COALESCE(MAX(CAST(localpart AS INT)), 0) FROM userapi_accounts WHERE CAST(localpart AS INT) <> 0 AND server_name = $1"
|
||||
|
||||
type accountsStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -119,16 +122,17 @@ func NewSQLiteAccountsTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// this account will be passwordless. Returns an error if this account already exists. Returns the account
|
||||
// on success.
|
||||
func (s *accountsStatements) InsertAccount(
|
||||
ctx context.Context, txn *sql.Tx, localpart, hash, appserviceID string, accountType api.AccountType,
|
||||
ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
hash, appserviceID string, accountType api.AccountType,
|
||||
) (*api.Account, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
stmt := s.insertAccountStmt
|
||||
|
||||
var err error
|
||||
if accountType != api.AccountTypeAppService {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, nil, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, nil, accountType)
|
||||
} else {
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, createdTimeMS, hash, appserviceID, accountType)
|
||||
_, err = sqlutil.TxStmt(txn, stmt).ExecContext(ctx, localpart, serverName, createdTimeMS, hash, appserviceID, accountType)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -136,42 +140,43 @@ func (s *accountsStatements) InsertAccount(
|
|||
|
||||
return &api.Account{
|
||||
Localpart: localpart,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
ServerName: s.serverName,
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
ServerName: serverName,
|
||||
AppServiceID: appserviceID,
|
||||
AccountType: accountType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) UpdatePassword(
|
||||
ctx context.Context, localpart, passwordHash string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
passwordHash string,
|
||||
) (err error) {
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart)
|
||||
_, err = s.updatePasswordStmt.ExecContext(ctx, passwordHash, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) DeactivateAccount(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart)
|
||||
_, err = s.deactivateAccountStmt.ExecContext(ctx, localpart, serverName)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectPasswordHash(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (hash string, err error) {
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart).Scan(&hash)
|
||||
err = s.selectPasswordHashStmt.QueryRowContext(ctx, localpart, serverName).Scan(&hash)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectAccountByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*api.Account, error) {
|
||||
var appserviceIDPtr sql.NullString
|
||||
var acc api.Account
|
||||
|
||||
stmt := s.selectAccountByLocalpartStmt
|
||||
err := stmt.QueryRowContext(ctx, localpart).Scan(&acc.Localpart, &appserviceIDPtr, &acc.AccountType)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName).Scan(&acc.Localpart, &acc.ServerName, &appserviceIDPtr, &acc.AccountType)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve user from the db")
|
||||
|
@ -182,20 +187,18 @@ func (s *accountsStatements) SelectAccountByLocalpart(
|
|||
acc.AppServiceID = appserviceIDPtr.String
|
||||
}
|
||||
|
||||
acc.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
acc.ServerName = s.serverName
|
||||
|
||||
acc.UserID = userutil.MakeUserID(acc.Localpart, acc.ServerName)
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (s *accountsStatements) SelectNewNumericLocalpart(
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
ctx context.Context, txn *sql.Tx, serverName gomatrixserverlib.ServerName,
|
||||
) (id int64, err error) {
|
||||
stmt := s.selectNewNumericLocalpartStmt
|
||||
if txn != nil {
|
||||
stmt = sqlutil.TxStmt(txn, stmt)
|
||||
}
|
||||
err = stmt.QueryRowContext(ctx).Scan(&id)
|
||||
err = stmt.QueryRowContext(ctx, serverName).Scan(&id)
|
||||
if err == sql.ErrNoRows {
|
||||
return 1, nil
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ func UpIsActive(ctx context.Context, tx *sql.Tx) error {
|
|||
ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
|
@ -14,6 +14,7 @@ func UpLastSeenTSIP(ctx context.Context, tx *sql.Tx) error {
|
|||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
|
|
|
@ -12,6 +12,7 @@ func UpAddAccountType(ctx context.Context, tx *sql.Tx) error {
|
|||
_, err := tx.ExecContext(ctx, `ALTER TABLE userapi_accounts RENAME TO userapi_accounts_tmp;
|
||||
CREATE TABLE userapi_accounts (
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT NOT NULL,
|
||||
password_hash TEXT,
|
||||
appservice_id TEXT,
|
||||
|
|
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
108
userapi/storage/sqlite3/deltas/2022110411000000_server_names.go
Normal file
|
@ -0,0 +1,108 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var serverNamesTables = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_devices",
|
||||
"userapi_notifications",
|
||||
"userapi_openid_tokens",
|
||||
"userapi_profiles",
|
||||
"userapi_pushers",
|
||||
"userapi_threepids",
|
||||
}
|
||||
|
||||
// These tables have a PRIMARY KEY constraint which we need to drop so
|
||||
// that we can recreate a new unique index that contains the server name.
|
||||
var serverNamesDropPK = []string{
|
||||
"userapi_accounts",
|
||||
"userapi_account_datas",
|
||||
"userapi_profiles",
|
||||
}
|
||||
|
||||
// These indices are out of date so let's drop them. They will get recreated
|
||||
// automatically.
|
||||
var serverNamesDropIndex = []string{
|
||||
"userapi_pusher_localpart_idx",
|
||||
"userapi_pusher_app_id_pushkey_localpart_idx",
|
||||
}
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNames(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name) FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(
|
||||
"SELECT COUNT(*) FROM pragma_table_info(%s) WHERE name='server_name'",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c); err != nil || c == 1 {
|
||||
logrus.Infof("Table %s already has column, skipping", table)
|
||||
continue
|
||||
}
|
||||
if c == 0 {
|
||||
q = fmt.Sprintf(
|
||||
"ALTER TABLE %s ADD COLUMN server_name TEXT NOT NULL DEFAULT '';",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("add server name to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, table := range serverNamesDropPK {
|
||||
q := fmt.Sprintf(
|
||||
"SELECT COUNT(name), sql FROM sqlite_schema WHERE type='table' AND name=%s;",
|
||||
pq.QuoteIdentifier(table),
|
||||
)
|
||||
var c int
|
||||
var sql string
|
||||
if err := tx.QueryRowContext(ctx, q).Scan(&c, &sql); err != nil || c == 0 {
|
||||
continue
|
||||
}
|
||||
q = fmt.Sprintf(`
|
||||
%s; -- create temporary table
|
||||
INSERT INTO %s SELECT * FROM %s; -- copy data
|
||||
DROP TABLE %s; -- drop original table
|
||||
ALTER TABLE %s RENAME TO %s; -- rename new table
|
||||
`,
|
||||
strings.Replace(sql, table, table+"_tmp", 1), // create temporary table
|
||||
table+"_tmp", table, // copy data
|
||||
table, // drop original table
|
||||
table+"_tmp", table, // rename new table
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop PK from %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
for _, index := range serverNamesDropIndex {
|
||||
q := fmt.Sprintf(
|
||||
"DROP INDEX IF EXISTS %s;",
|
||||
pq.QuoteIdentifier(index),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("drop index %q error: %w", index, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,28 @@
|
|||
package deltas
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// I know what you're thinking: you're wondering "why doesn't this use $1
|
||||
// and pass variadic parameters to ExecContext?" — the answer is because
|
||||
// PostgreSQL doesn't expect the table name to be specified as a substituted
|
||||
// argument in that way so it results in a syntax error in the query.
|
||||
|
||||
func UpServerNamesPopulate(ctx context.Context, tx *sql.Tx, serverName gomatrixserverlib.ServerName) error {
|
||||
for _, table := range serverNamesTables {
|
||||
q := fmt.Sprintf(
|
||||
"UPDATE %s SET server_name = %s WHERE server_name = '';",
|
||||
pq.QuoteIdentifier(table), pq.QuoteLiteral(string(serverName)),
|
||||
)
|
||||
if _, err := tx.ExecContext(ctx, q); err != nil {
|
||||
return fmt.Errorf("write server names to %q error: %w", table, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -40,49 +40,50 @@ CREATE TABLE IF NOT EXISTS userapi_devices (
|
|||
session_id INTEGER,
|
||||
device_id TEXT ,
|
||||
localpart TEXT ,
|
||||
server_name TEXT NOT NULL,
|
||||
created_ts BIGINT,
|
||||
display_name TEXT,
|
||||
last_seen_ts BIGINT,
|
||||
ip TEXT,
|
||||
user_agent TEXT,
|
||||
|
||||
UNIQUE (localpart, device_id)
|
||||
UNIQUE (localpart, server_name, device_id)
|
||||
);
|
||||
`
|
||||
|
||||
const insertDeviceSQL = "" +
|
||||
"INSERT INTO userapi_devices (device_id, localpart, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)"
|
||||
"INSERT INTO userapi_devices (device_id, localpart, server_name, access_token, created_ts, display_name, session_id, last_seen_ts, ip, user_agent)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
|
||||
const selectDevicesCountSQL = "" +
|
||||
"SELECT COUNT(access_token) FROM userapi_devices"
|
||||
|
||||
const selectDeviceByTokenSQL = "" +
|
||||
"SELECT session_id, device_id, localpart FROM userapi_devices WHERE access_token = $1"
|
||||
"SELECT session_id, device_id, localpart, server_name FROM userapi_devices WHERE access_token = $1"
|
||||
|
||||
const selectDeviceByIDSQL = "" +
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 and device_id = $2"
|
||||
"SELECT display_name, last_seen_ts, ip FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id = $3"
|
||||
|
||||
const selectDevicesByLocalpartSQL = "" +
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND device_id != $2 ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, display_name, last_seen_ts, ip, user_agent FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3 ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceNameSQL = "" +
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
|
||||
"UPDATE userapi_devices SET display_name = $1 WHERE localpart = $2 AND server_name = $3 AND device_id = $4"
|
||||
|
||||
const deleteDeviceSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2"
|
||||
"DELETE FROM userapi_devices WHERE device_id = $1 AND localpart = $2 AND server_name = $3"
|
||||
|
||||
const deleteDevicesByLocalpartSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id != $2"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id != $3"
|
||||
|
||||
const deleteDevicesSQL = "" +
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND device_id IN ($2)"
|
||||
"DELETE FROM userapi_devices WHERE localpart = $1 AND server_name = $2 AND device_id IN ($3)"
|
||||
|
||||
const selectDevicesByIDSQL = "" +
|
||||
"SELECT device_id, localpart, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
"SELECT device_id, localpart, server_name, display_name, last_seen_ts FROM userapi_devices WHERE device_id IN ($1) ORDER BY last_seen_ts DESC"
|
||||
|
||||
const updateDeviceLastSeen = "" +
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND device_id = $5"
|
||||
"UPDATE userapi_devices SET last_seen_ts = $1, ip = $2, user_agent = $3 WHERE localpart = $4 AND server_name = $5 AND device_id = $6"
|
||||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -135,8 +136,9 @@ func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName)
|
|||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
accessToken string, displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||
var sessionID int64
|
||||
|
@ -146,12 +148,12 @@ func (s *devicesStatements) InsertDevice(
|
|||
return nil, err
|
||||
}
|
||||
sessionID++
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
if _, err := insertStmt.ExecContext(ctx, id, localpart, serverName, accessToken, createdTimeMS, displayName, sessionID, createdTimeMS, ipAddr, userAgent); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &api.Device{
|
||||
ID: id,
|
||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||
UserID: userutil.MakeUserID(localpart, serverName),
|
||||
AccessToken: accessToken,
|
||||
SessionID: sessionID,
|
||||
LastSeenTS: createdTimeMS,
|
||||
|
@ -161,44 +163,52 @@ func (s *devicesStatements) InsertDevice(
|
|||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, id string,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart)
|
||||
_, err := stmt.ExecContext(ctx, id, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
orig := strings.Replace(deleteDevicesSQL, "($3)", sqlutil.QueryVariadicOffset(len(devices), 2), 1)
|
||||
prep, err := s.db.Prepare(orig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, prep)
|
||||
params := make([]interface{}, len(devices)+1)
|
||||
params := make([]interface{}, len(devices)+2)
|
||||
params[0] = localpart
|
||||
params[1] = serverName
|
||||
for i, v := range devices {
|
||||
params[i+1] = v
|
||||
params[i+2] = v
|
||||
}
|
||||
_, err = stmt.ExecContext(ctx, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
_, err := stmt.ExecContext(ctx, localpart, exceptDeviceID)
|
||||
_, err := stmt.ExecContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, displayName, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -207,10 +217,11 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
stmt := s.selectDeviceByTokenStmt
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart, &serverName)
|
||||
if err == nil {
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
dev.AccessToken = accessToken
|
||||
}
|
||||
return &dev, err
|
||||
|
@ -219,16 +230,18 @@ func (s *devicesStatements) SelectDeviceByToken(
|
|||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
var displayName, ip sql.NullString
|
||||
stmt := s.selectDeviceByIDStmt
|
||||
var lastseenTS sql.NullInt64
|
||||
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
err := stmt.QueryRowContext(ctx, localpart, serverName, deviceID).Scan(&displayName, &lastseenTS, &ip)
|
||||
if err == nil {
|
||||
dev.ID = deviceID
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
if displayName.Valid {
|
||||
dev.DisplayName = displayName.String
|
||||
}
|
||||
|
@ -243,10 +256,12 @@ func (s *devicesStatements) SelectDeviceByID(
|
|||
}
|
||||
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, exceptDeviceID)
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectDevicesByLocalpartStmt).QueryContext(ctx, localpart, serverName, exceptDeviceID)
|
||||
|
||||
if err != nil {
|
||||
return devices, err
|
||||
|
@ -276,7 +291,7 @@ func (s *devicesStatements) SelectDevicesByLocalpart(
|
|||
dev.UserAgent = useragent.String
|
||||
}
|
||||
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
|
||||
|
@ -298,10 +313,11 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
var devices []api.Device
|
||||
var dev api.Device
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
var displayName sql.NullString
|
||||
var lastseents sql.NullInt64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &displayName, &lastseents); err != nil {
|
||||
if err := rows.Scan(&dev.ID, &localpart, &serverName, &displayName, &lastseents); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if displayName.Valid {
|
||||
|
@ -310,15 +326,15 @@ func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
if lastseents.Valid {
|
||||
dev.LastSeenTS = lastseents.Int64
|
||||
}
|
||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||
dev.UserID = userutil.MakeUserID(localpart, serverName)
|
||||
devices = append(devices, dev)
|
||||
}
|
||||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr, userAgent string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, deviceID, ipAddr, userAgent string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, deviceID)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, userAgent, localpart, serverName, deviceID)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -43,6 +43,7 @@ const notificationSchema = `
|
|||
CREATE TABLE IF NOT EXISTS userapi_notifications (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
stream_pos BIGINT NOT NULL,
|
||||
|
@ -52,33 +53,33 @@ CREATE TABLE IF NOT EXISTS userapi_notifications (
|
|||
read BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_event_id_idx ON userapi_notifications(localpart, server_name, room_id, event_id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_room_id_id_idx ON userapi_notifications(localpart, server_name, room_id, id);
|
||||
CREATE INDEX IF NOT EXISTS userapi_notification_localpart_id_idx ON userapi_notifications(localpart, server_name, id);
|
||||
`
|
||||
|
||||
const insertNotificationSQL = "" +
|
||||
"INSERT INTO userapi_notifications (localpart, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7)"
|
||||
"INSERT INTO userapi_notifications (localpart, server_name, room_id, event_id, stream_pos, ts_ms, highlight, notification_json) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
|
||||
|
||||
const deleteNotificationsUpToSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND room_id = $2 AND stream_pos <= $3"
|
||||
"DELETE FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND stream_pos <= $4"
|
||||
|
||||
const updateNotificationReadSQL = "" +
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND room_id = $3 AND stream_pos <= $4 AND read <> $1"
|
||||
"UPDATE userapi_notifications SET read = $1 WHERE localpart = $2 AND server_name = $3 AND room_id = $4 AND stream_pos <= $5 AND read <> $1"
|
||||
|
||||
const selectNotificationSQL = "" +
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND id > $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $4"
|
||||
"SELECT id, room_id, ts_ms, read, notification_json FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND id > $3 AND (" +
|
||||
"(($4 & 1) <> 0 AND highlight) OR (($4 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read ORDER BY localpart, id LIMIT $5"
|
||||
|
||||
const selectNotificationCountSQL = "" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND (" +
|
||||
"(($2 & 1) <> 0 AND highlight) OR (($2 & 2) <> 0 AND NOT highlight)" +
|
||||
"SELECT COUNT(*) FROM userapi_notifications WHERE localpart = $1 AND server_name = $2 AND (" +
|
||||
"(($3 & 1) <> 0 AND highlight) OR (($3 & 2) <> 0 AND NOT highlight)" +
|
||||
") AND NOT read"
|
||||
|
||||
const selectRoomNotificationCountsSQL = "" +
|
||||
"SELECT COUNT(*), COUNT(*) FILTER (WHERE highlight) FROM userapi_notifications " +
|
||||
"WHERE localpart = $1 AND room_id = $2 AND NOT read"
|
||||
"WHERE localpart = $1 AND server_name = $2 AND room_id = $3 AND NOT read"
|
||||
|
||||
const cleanNotificationsSQL = "" +
|
||||
"DELETE FROM userapi_notifications WHERE" +
|
||||
|
@ -111,7 +112,7 @@ func (s *notificationsStatements) Clean(ctx context.Context, txn *sql.Tx) error
|
|||
}
|
||||
|
||||
// Insert inserts a notification into the database.
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, eventID string, pos uint64, highlight bool, n *api.Notification) error {
|
||||
roomID, tsMS := n.RoomID, n.TS
|
||||
nn := *n
|
||||
// Clears out fields that have their own columns to (1) shrink the
|
||||
|
@ -122,13 +123,13 @@ func (s *notificationsStatements) Insert(ctx context.Context, txn *sql.Tx, local
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
_, err = sqlutil.TxStmt(txn, s.insertStmt).ExecContext(ctx, localpart, serverName, roomID, eventID, pos, tsMS, highlight, string(bs))
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteUpTo deletes all previous notifications, up to and including the event.
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.deleteUpToStmt).ExecContext(ctx, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -141,8 +142,8 @@ func (s *notificationsStatements) DeleteUpTo(ctx context.Context, txn *sql.Tx, l
|
|||
}
|
||||
|
||||
// UpdateRead updates the "read" value for an event.
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, roomID, pos)
|
||||
func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string, pos uint64, v bool) (affected bool, _ error) {
|
||||
res, err := sqlutil.TxStmt(txn, s.updateReadStmt).ExecContext(ctx, v, localpart, serverName, roomID, pos)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -154,8 +155,8 @@ func (s *notificationsStatements) UpdateRead(ctx context.Context, txn *sql.Tx, l
|
|||
return nrows > 0, nil
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, fromID, uint32(filter), limit)
|
||||
func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, fromID int64, limit int, filter tables.NotificationFilter) ([]*api.Notification, int64, error) {
|
||||
rows, err := sqlutil.TxStmt(txn, s.selectStmt).QueryContext(ctx, localpart, serverName, fromID, uint32(filter), limit)
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
|
@ -197,12 +198,12 @@ func (s *notificationsStatements) Select(ctx context.Context, txn *sql.Tx, local
|
|||
return notifs, maxID, rows.Err()
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, uint32(filter)).Scan(&count)
|
||||
func (s *notificationsStatements) SelectCount(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, filter tables.NotificationFilter) (count int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectCountStmt).QueryRowContext(ctx, localpart, serverName, uint32(filter)).Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, roomID).Scan(&total, &highlight)
|
||||
func (s *notificationsStatements) SelectRoomCounts(ctx context.Context, txn *sql.Tx, localpart string, serverName gomatrixserverlib.ServerName, roomID string) (total int64, highlight int64, err error) {
|
||||
err = sqlutil.TxStmt(txn, s.selectRoomCountsStmt).QueryRowContext(ctx, localpart, serverName, roomID).Scan(&total, &highlight)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package sqlite3
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
|
@ -18,16 +19,17 @@ CREATE TABLE IF NOT EXISTS userapi_openid_tokens (
|
|||
token TEXT NOT NULL PRIMARY KEY,
|
||||
-- The Matrix user ID for this account
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- When the token expires, as a unix timestamp (ms resolution).
|
||||
token_expires_at_ms BIGINT NOT NULL
|
||||
);
|
||||
`
|
||||
|
||||
const insertOpenIDTokenSQL = "" +
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, token_expires_at_ms) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_openid_tokens(token, localpart, server_name, token_expires_at_ms) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectOpenIDTokenSQL = "" +
|
||||
"SELECT localpart, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
"SELECT localpart, server_name, token_expires_at_ms FROM userapi_openid_tokens WHERE token = $1"
|
||||
|
||||
type openIDTokenStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -56,11 +58,11 @@ func NewSQLiteOpenIDTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (
|
|||
func (s *openIDTokenStatements) InsertOpenIDToken(
|
||||
ctx context.Context,
|
||||
txn *sql.Tx,
|
||||
token, localpart string,
|
||||
token, localpart string, serverName gomatrixserverlib.ServerName,
|
||||
expiresAtMS int64,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertTokenStmt)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, expiresAtMS)
|
||||
_, err = stmt.ExecContext(ctx, token, localpart, serverName, expiresAtMS)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -71,10 +73,13 @@ func (s *openIDTokenStatements) SelectOpenIDTokenAtrributes(
|
|||
token string,
|
||||
) (*api.OpenIDTokenAttributes, error) {
|
||||
var openIDTokenAttrs api.OpenIDTokenAttributes
|
||||
var localpart string
|
||||
var serverName gomatrixserverlib.ServerName
|
||||
err := s.selectTokenStmt.QueryRowContext(ctx, token).Scan(
|
||||
&openIDTokenAttrs.UserID,
|
||||
&localpart, &serverName,
|
||||
&openIDTokenAttrs.ExpiresAtMS,
|
||||
)
|
||||
openIDTokenAttrs.UserID = fmt.Sprintf("@%s:%s", localpart, serverName)
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
log.WithError(err).Error("Unable to retrieve token from the db")
|
||||
|
|
|
@ -23,36 +23,40 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
const profilesSchema = `
|
||||
-- Stores data about accounts profiles.
|
||||
CREATE TABLE IF NOT EXISTS userapi_profiles (
|
||||
-- The Matrix user ID localpart for this account
|
||||
localpart TEXT NOT NULL PRIMARY KEY,
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
-- The display name for this account
|
||||
display_name TEXT,
|
||||
-- The URL of the avatar for this account
|
||||
avatar_url TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_profiles_idx ON userapi_profiles(localpart, server_name);
|
||||
`
|
||||
|
||||
const insertProfileSQL = "" +
|
||||
"INSERT INTO userapi_profiles(localpart, display_name, avatar_url) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_profiles(localpart, server_name, display_name, avatar_url) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const selectProfileByLocalpartSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const setAvatarURLSQL = "" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET avatar_url = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING display_name"
|
||||
|
||||
const setDisplayNameSQL = "" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2" +
|
||||
"UPDATE userapi_profiles SET display_name = $1 WHERE localpart = $2 AND server_name = $3" +
|
||||
" RETURNING avatar_url"
|
||||
|
||||
const selectProfilesBySearchSQL = "" +
|
||||
"SELECT localpart, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
"SELECT localpart, server_name, display_name, avatar_url FROM userapi_profiles WHERE localpart LIKE $1 OR display_name LIKE $1 LIMIT $2"
|
||||
|
||||
type profilesStatements struct {
|
||||
db *sql.DB
|
||||
|
@ -83,18 +87,20 @@ func NewSQLiteProfilesTable(db *sql.DB, serverNoticesLocalpart string) (tables.P
|
|||
}
|
||||
|
||||
func (s *profilesStatements) InsertProfile(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, "", "")
|
||||
_, err := sqlutil.TxStmt(txn, s.insertProfileStmt).ExecContext(ctx, localpart, serverName, "", "")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SelectProfileByLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (*authtypes.Profile, error) {
|
||||
var profile authtypes.Profile
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart).Scan(
|
||||
&profile.Localpart, &profile.DisplayName, &profile.AvatarURL,
|
||||
err := s.selectProfileByLocalpartStmt.QueryRowContext(ctx, localpart, serverName).Scan(
|
||||
&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -103,13 +109,16 @@ func (s *profilesStatements) SelectProfileByLocalpart(
|
|||
}
|
||||
|
||||
func (s *profilesStatements) SetAvatarURL(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, avatarURL string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
avatarURL string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
AvatarURL: avatarURL,
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
AvatarURL: avatarURL,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
|
@ -117,18 +126,21 @@ func (s *profilesStatements) SetAvatarURL(
|
|||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setAvatarURLStmt)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart).Scan(&profile.DisplayName)
|
||||
err = stmt.QueryRowContext(ctx, avatarURL, localpart, serverName).Scan(&profile.DisplayName)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
func (s *profilesStatements) SetDisplayName(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, displayName string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
displayName string,
|
||||
) (*authtypes.Profile, bool, error) {
|
||||
profile := &authtypes.Profile{
|
||||
Localpart: localpart,
|
||||
ServerName: string(serverName),
|
||||
DisplayName: displayName,
|
||||
}
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart)
|
||||
old, err := s.SelectProfileByLocalpart(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return old, false, err
|
||||
}
|
||||
|
@ -136,7 +148,7 @@ func (s *profilesStatements) SetDisplayName(
|
|||
return old, false, nil
|
||||
}
|
||||
stmt := sqlutil.TxStmt(txn, s.setDisplayNameStmt)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart).Scan(&profile.AvatarURL)
|
||||
err = stmt.QueryRowContext(ctx, displayName, localpart, serverName).Scan(&profile.AvatarURL)
|
||||
return profile, true, err
|
||||
}
|
||||
|
||||
|
@ -154,7 +166,7 @@ func (s *profilesStatements) SelectProfilesBySearch(
|
|||
defer internal.CloseAndLogIfError(ctx, rows, "selectProfilesBySearch: rows.close() failed")
|
||||
for rows.Next() {
|
||||
var profile authtypes.Profile
|
||||
if err := rows.Scan(&profile.Localpart, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
if err := rows.Scan(&profile.Localpart, &profile.ServerName, &profile.DisplayName, &profile.AvatarURL); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if profile.Localpart != s.serverNoticesLocalpart {
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
// See https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-pushers
|
||||
|
@ -33,6 +34,7 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
-- The Matrix user ID localpart for this pusher
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
session_id BIGINT DEFAULT NULL,
|
||||
profile_tag TEXT,
|
||||
kind TEXT NOT NULL,
|
||||
|
@ -49,22 +51,22 @@ CREATE TABLE IF NOT EXISTS userapi_pushers (
|
|||
CREATE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_idx ON userapi_pushers(app_id, pushkey);
|
||||
|
||||
-- For faster retrieving by localpart.
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart);
|
||||
CREATE INDEX IF NOT EXISTS userapi_pusher_localpart_idx ON userapi_pushers(localpart, server_name);
|
||||
|
||||
-- Pushkey must be unique for a given user and app.
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart);
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS userapi_pusher_app_id_pushkey_localpart_idx ON userapi_pushers(app_id, pushkey, localpart, server_name);
|
||||
`
|
||||
|
||||
const insertPusherSQL = "" +
|
||||
"INSERT INTO userapi_pushers (localpart, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart) DO UPDATE SET session_id = $2, pushkey_ts_ms = $4, kind = $5, app_display_name = $7, device_display_name = $8, profile_tag = $9, lang = $10, data = $11"
|
||||
"INSERT INTO userapi_pushers (localpart, server_name, session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data)" +
|
||||
"VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)" +
|
||||
"ON CONFLICT (app_id, pushkey, localpart, server_name) DO UPDATE SET session_id = $3, pushkey_ts_ms = $5, kind = $6, app_display_name = $8, device_display_name = $9, profile_tag = $10, lang = $11, data = $12"
|
||||
|
||||
const selectPushersSQL = "" +
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1"
|
||||
"SELECT session_id, pushkey, pushkey_ts_ms, kind, app_id, app_display_name, device_display_name, profile_tag, lang, data FROM userapi_pushers WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const deletePusherSQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3"
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2 AND localpart = $3 AND server_name = $4"
|
||||
|
||||
const deletePushersByAppIdAndPushKeySQL = "" +
|
||||
"DELETE FROM userapi_pushers WHERE app_id = $1 AND pushkey = $2"
|
||||
|
@ -95,18 +97,19 @@ type pushersStatements struct {
|
|||
// Returns nil error success.
|
||||
func (s *pushersStatements) InsertPusher(
|
||||
ctx context.Context, txn *sql.Tx, session_id int64,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data, localpart string,
|
||||
pushkey string, pushkeyTS int64, kind api.PusherKind, appid, appdisplayname, devicedisplayname, profiletag, lang, data,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
logrus.Debugf("Created pusher %d", session_id)
|
||||
_, err := sqlutil.TxStmt(txn, s.insertPusherStmt).ExecContext(ctx, localpart, serverName, session_id, pushkey, pushkeyTS, kind, appid, appdisplayname, devicedisplayname, profiletag, lang, data)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *pushersStatements) SelectPushers(
|
||||
ctx context.Context, txn *sql.Tx, localpart string,
|
||||
ctx context.Context, txn *sql.Tx,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) ([]api.Pusher, error) {
|
||||
pushers := []api.Pusher{}
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectPushersStmt.QueryContext(ctx, localpart, serverName)
|
||||
|
||||
if err != nil {
|
||||
return pushers, err
|
||||
|
@ -143,9 +146,10 @@ func (s *pushersStatements) SelectPushers(
|
|||
|
||||
// deletePusher removes a single pusher by pushkey and user localpart.
|
||||
func (s *pushersStatements) DeletePusher(
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, appid, pushkey,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) error {
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart)
|
||||
_, err := sqlutil.TxStmt(txn, s.deletePusherStmt).ExecContext(ctx, appid, pushkey, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
|
@ -41,18 +43,24 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
Up: deltas.UpRenameTables,
|
||||
Down: deltas.DownRenameTables,
|
||||
})
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNames(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
accountsTable, err := NewSQLiteAccountsTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountsTable: %w", err)
|
||||
}
|
||||
accountDataTable, err := NewSQLiteAccountDataTable(db)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteAccountDataTable: %w", err)
|
||||
}
|
||||
devicesTable, err := NewSQLiteDevicesTable(db, serverName)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteDevicesTable: %w", err)
|
||||
|
@ -93,6 +101,18 @@ func NewDatabase(base *base.BaseDendrite, dbProperties *config.DatabaseOptions,
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("NewSQLiteStatsTable: %w", err)
|
||||
}
|
||||
|
||||
m = sqlutil.NewMigrator(db)
|
||||
m.AddMigrations(sqlutil.Migration{
|
||||
Version: "userapi: server names populate",
|
||||
Up: func(ctx context.Context, txn *sql.Tx) error {
|
||||
return deltas.UpServerNamesPopulate(ctx, txn, serverName)
|
||||
},
|
||||
})
|
||||
if err = m.Up(base.Context()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &shared.Database{
|
||||
AccountDatas: accountDataTable,
|
||||
Accounts: accountsTable,
|
||||
|
|
|
@ -21,6 +21,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/auth/authtypes"
|
||||
)
|
||||
|
@ -34,21 +35,22 @@ CREATE TABLE IF NOT EXISTS userapi_threepids (
|
|||
medium TEXT NOT NULL DEFAULT 'email',
|
||||
-- The localpart of the Matrix user ID associated to this 3PID
|
||||
localpart TEXT NOT NULL,
|
||||
server_name TEXT NOT NULL,
|
||||
|
||||
PRIMARY KEY(threepid, medium)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart);
|
||||
CREATE INDEX IF NOT EXISTS account_threepid_localpart ON userapi_threepids(localpart, server_name);
|
||||
`
|
||||
|
||||
const selectLocalpartForThreePIDSQL = "" +
|
||||
"SELECT localpart FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
"SELECT localpart, server_name FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
||||
const selectThreePIDsForLocalpartSQL = "" +
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1"
|
||||
"SELECT threepid, medium FROM userapi_threepids WHERE localpart = $1 AND server_name = $2"
|
||||
|
||||
const insertThreePIDSQL = "" +
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart) VALUES ($1, $2, $3)"
|
||||
"INSERT INTO userapi_threepids (threepid, medium, localpart, server_name) VALUES ($1, $2, $3, $4)"
|
||||
|
||||
const deleteThreePIDSQL = "" +
|
||||
"DELETE FROM userapi_threepids WHERE threepid = $1 AND medium = $2"
|
||||
|
@ -79,19 +81,20 @@ func NewSQLiteThreePIDTable(db *sql.DB) (tables.ThreePIDTable, error) {
|
|||
|
||||
func (s *threepidStatements) SelectLocalpartForThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid string, medium string,
|
||||
) (localpart string, err error) {
|
||||
) (localpart string, serverName gomatrixserverlib.ServerName, err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.selectLocalpartForThreePIDStmt)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart)
|
||||
err = stmt.QueryRowContext(ctx, threepid, medium).Scan(&localpart, &serverName)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
return "", "", nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
||||
ctx context.Context, localpart string,
|
||||
ctx context.Context,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (threepids []authtypes.ThreePID, err error) {
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart)
|
||||
rows, err := s.selectThreePIDsForLocalpartStmt.QueryContext(ctx, localpart, serverName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -113,10 +116,11 @@ func (s *threepidStatements) SelectThreePIDsForLocalpart(
|
|||
}
|
||||
|
||||
func (s *threepidStatements) InsertThreePID(
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium, localpart string,
|
||||
ctx context.Context, txn *sql.Tx, threepid, medium,
|
||||
localpart string, serverName gomatrixserverlib.ServerName,
|
||||
) (err error) {
|
||||
stmt := sqlutil.TxStmt(txn, s.insertThreePIDStmt)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart)
|
||||
_, err = stmt.ExecContext(ctx, threepid, medium, localpart, serverName)
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue