Move account/device DBs to userapi (#1141)

This commit is contained in:
Kegsay 2020-06-17 12:05:56 +01:00 committed by GitHub
parent 5d5aa0a31d
commit e09d24e732
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
51 changed files with 40 additions and 40 deletions

View file

@ -0,0 +1,38 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"context"
"github.com/matrix-org/dendrite/userapi/api"
)
type Database interface {
GetDeviceByAccessToken(ctx context.Context, token string) (*api.Device, error)
GetDeviceByID(ctx context.Context, localpart, deviceID string) (*api.Device, error)
GetDevicesByLocalpart(ctx context.Context, localpart string) ([]api.Device, error)
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
CreateDevice(ctx context.Context, localpart string, deviceID *string, accessToken string, displayName *string) (dev *api.Device, returnErr error)
UpdateDevice(ctx context.Context, localpart, deviceID string, displayName *string) error
RemoveDevice(ctx context.Context, deviceID, localpart string) error
RemoveDevices(ctx context.Context, localpart string, devices []string) error
RemoveAllDevices(ctx context.Context, localpart string) error
}

View file

@ -0,0 +1,249 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"database/sql"
"time"
"github.com/lib/pq"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
-- The Matrix user ID localpart for this device. This is preferable to storing the full user_id
-- as it is smaller, makes it clearer that we only manage devices for our own users, and may make
-- migration to different domain names easier.
localpart TEXT NOT NULL,
-- When this devices was first recognised on the network, as a unix timestamp (ms resolution).
created_ts BIGINT NOT NULL,
-- The display name, human friendlier than device_id and updatable
display_name TEXT
-- TODO: device keys, device display names, last used ts and IP address?, token restrictions (if 3rd-party OAuth app)
);
-- Device IDs must be unique for a given user.
CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(localpart, device_id);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id = ANY($2)"
type devicesStatements struct {
insertDeviceStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
deleteDevicesStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
_, err = db.Exec(devicesSchema)
if err != nil {
return
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
return
}
s.serverName = server
return
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
// deleteDevice removes a single device by id and user localpart.
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
// deleteDevices removes a single or multiple devices by ids and user localpart.
// Returns an error if the execution failed.
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
_, err := stmt.ExecContext(ctx, localpart, pq.Array(devices))
return err
}
// deleteDevicesByLocalpart removes all devices for the
// given user localpart.
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
}
return &dev, err
}
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return devices, err
}
defer internal.CloseAndLogIfError(ctx, rows, "selectDevicesByLocalpart: rows.close() failed")
for rows.Next() {
var dev api.Device
var id, displayname sql.NullString
err = rows.Scan(&id, &displayname)
if err != nil {
return devices, err
}
if id.Valid {
dev.ID = id.String
}
if displayname.Valid {
dev.DisplayName = displayname.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, rows.Err()
}

View file

@ -0,0 +1,182 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postgres
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
if db, err = sqlutil.Open("postgres", dataSourceName, dbProperties); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -0,0 +1,249 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"database/sql"
"strings"
"time"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
-- CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
access_token TEXT PRIMARY KEY,
session_id INTEGER,
device_id TEXT ,
localpart TEXT ,
created_ts BIGINT,
display_name TEXT,
UNIQUE (localpart, device_id)
);
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices (device_id, localpart, access_token, created_ts, display_name, session_id)" +
" VALUES ($1, $2, $3, $4, $5, $6)"
const selectDevicesCountSQL = "" +
"SELECT COUNT(access_token) FROM device_devices"
const selectDeviceByTokenSQL = "" +
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
const selectDevicesByLocalpartSQL = "" +
"SELECT device_id, display_name FROM device_devices WHERE localpart = $1"
const updateDeviceNameSQL = "" +
"UPDATE device_devices SET display_name = $1 WHERE localpart = $2 AND device_id = $3"
const deleteDeviceSQL = "" +
"DELETE FROM device_devices WHERE device_id = $1 AND localpart = $2"
const deleteDevicesByLocalpartSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1"
const deleteDevicesSQL = "" +
"DELETE FROM device_devices WHERE localpart = $1 AND device_id IN ($2)"
type devicesStatements struct {
db *sql.DB
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
selectDeviceByIDStmt *sql.Stmt
selectDevicesByLocalpartStmt *sql.Stmt
updateDeviceNameStmt *sql.Stmt
deleteDeviceStmt *sql.Stmt
deleteDevicesByLocalpartStmt *sql.Stmt
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
s.db = db
_, err = db.Exec(devicesSchema)
if err != nil {
return
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
s.serverName = server
return
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string,
) (*api.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
countStmt := sqlutil.TxStmt(txn, s.selectDevicesCountStmt)
insertStmt := sqlutil.TxStmt(txn, s.insertDeviceStmt)
if err := countStmt.QueryRowContext(ctx).Scan(&sessionID); err != nil {
return nil, err
}
sessionID++
if _, err := insertStmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName, sessionID); err != nil {
return nil, err
}
return &api.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
func (s *devicesStatements) deleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
_, err := stmt.ExecContext(ctx, id, localpart)
return err
}
func (s *devicesStatements) deleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($1)", sqlutil.QueryVariadic(len(devices)), 1)
prep, err := s.db.Prepare(orig)
if err != nil {
return err
}
stmt := sqlutil.TxStmt(txn, prep)
params := make([]interface{}, len(devices)+1)
params[0] = localpart
for i, v := range devices {
params[i+1] = v
}
params = append(params, params...)
_, err = stmt.ExecContext(ctx, params...)
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
_, err := stmt.ExecContext(ctx, localpart)
return err
}
func (s *devicesStatements) updateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
_, err := stmt.ExecContext(ctx, displayName, localpart, deviceID)
return err
}
func (s *devicesStatements) selectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken
}
return &dev, err
}
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
stmt := s.selectDeviceByIDStmt
err := stmt.QueryRowContext(ctx, localpart, deviceID).Scan(&dev.DisplayName)
if err == nil {
dev.ID = deviceID
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
}
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
devices := []api.Device{}
rows, err := s.selectDevicesByLocalpartStmt.QueryContext(ctx, localpart)
if err != nil {
return devices, err
}
for rows.Next() {
var dev api.Device
var id, displayname sql.NullString
err = rows.Scan(&id, &displayname)
if err != nil {
return devices, err
}
if id.Valid {
dev.ID = id.String
}
if displayname.Valid {
dev.DisplayName = displayname.String
}
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
devices = append(devices, dev)
}
return devices, nil
}

View file

@ -0,0 +1,188 @@
// Copyright 2017 Vector Creations Ltd
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sqlite3
import (
"context"
"crypto/rand"
"database/sql"
"encoding/base64"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/gomatrixserverlib"
_ "github.com/mattn/go-sqlite3"
)
// The length of generated device IDs
var deviceIDByteLength = 6
// Database represents a device database.
type Database struct {
db *sql.DB
devices devicesStatements
}
// NewDatabase creates a new device database
func NewDatabase(dataSourceName string, serverName gomatrixserverlib.ServerName) (*Database, error) {
var db *sql.DB
var err error
cs, err := sqlutil.ParseFileURI(dataSourceName)
if err != nil {
return nil, err
}
if db, err = sqlutil.Open(sqlutil.SQLiteDriverName(), cs, nil); err != nil {
return nil, err
}
d := devicesStatements{}
if err = d.prepare(db, serverName); err != nil {
return nil, err
}
return &Database{db, d}, nil
}
// GetDeviceByAccessToken returns the device matching the given access token.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByAccessToken(
ctx context.Context, token string,
) (*api.Device, error) {
return d.devices.selectDeviceByToken(ctx, token)
}
// GetDeviceByID returns the device matching the given ID.
// Returns sql.ErrNoRows if no matching device was found.
func (d *Database) GetDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
return d.devices.selectDeviceByID(ctx, localpart, deviceID)
}
// GetDevicesByLocalpart returns the devices matching the given localpart.
func (d *Database) GetDevicesByLocalpart(
ctx context.Context, localpart string,
) ([]api.Device, error) {
return d.devices.selectDevicesByLocalpart(ctx, localpart)
}
// CreateDevice makes a new device associated with the given user ID localpart.
// If there is already a device with the same device ID for this user, that access token will be revoked
// and replaced with the given accessToken. If the given accessToken is already in use for another device,
// an error will be returned.
// If no device ID is given one is generated.
// Returns the device on success.
func (d *Database) CreateDevice(
ctx context.Context, localpart string, deviceID *string, accessToken string,
displayName *string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.devices.deleteDevice(ctx, txn, *deviceID, localpart); err != nil {
return err
}
dev, err = d.devices.insertDevice(ctx, txn, *deviceID, localpart, accessToken, displayName)
return err
})
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
var newDeviceID string
for i := 1; i <= 5; i++ {
newDeviceID, returnErr = generateDeviceID()
if returnErr != nil {
return
}
returnErr = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
var err error
dev, err = d.devices.insertDevice(ctx, txn, newDeviceID, localpart, accessToken, displayName)
return err
})
if returnErr == nil {
return
}
}
}
return
}
// generateDeviceID creates a new device id. Returns an error if failed to generate
// random bytes.
func generateDeviceID() (string, error) {
b := make([]byte, deviceIDByteLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
// url-safe no padding
return base64.RawURLEncoding.EncodeToString(b), nil
}
// UpdateDevice updates the given device with the display name.
// Returns SQL error if there are problems and nil on success.
func (d *Database) UpdateDevice(
ctx context.Context, localpart, deviceID string, displayName *string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
return d.devices.updateDeviceName(ctx, txn, localpart, deviceID, displayName)
})
}
// RemoveDevice revokes a device by deleting the entry in the database
// matching with the given device ID and user ID localpart.
// If the device doesn't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevice(
ctx context.Context, deviceID, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevice(ctx, txn, deviceID, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveDevices revokes one or more devices by deleting the entry in the database
// matching with the given device IDs and user ID localpart.
// If the devices don't exist, it will not return an error
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveDevices(
ctx context.Context, localpart string, devices []string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevices(ctx, txn, localpart, devices); err != sql.ErrNoRows {
return err
}
return nil
})
}
// RemoveAllDevices revokes devices by deleting the entry in the
// database matching the given user ID localpart.
// If something went wrong during the deletion, it will return the SQL error.
func (d *Database) RemoveAllDevices(
ctx context.Context, localpart string,
) error {
return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error {
if err := d.devices.deleteDevicesByLocalpart(ctx, txn, localpart); err != sql.ErrNoRows {
return err
}
return nil
})
}

View file

@ -0,0 +1,43 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// +build !wasm
package devices
import (
"net/url"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/devices/postgres"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
// NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme)
// and sets postgres connection parameters
func NewDatabase(dataSourceName string, dbProperties sqlutil.DbProperties, serverName gomatrixserverlib.ServerName) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return postgres.NewDatabase(dataSourceName, dbProperties, serverName)
}
switch uri.Scheme {
case "postgres":
return postgres.NewDatabase(dataSourceName, dbProperties, serverName)
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return postgres.NewDatabase(dataSourceName, dbProperties, serverName)
}
}

View file

@ -0,0 +1,43 @@
// Copyright 2020 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package devices
import (
"fmt"
"net/url"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3"
"github.com/matrix-org/gomatrixserverlib"
)
func NewDatabase(
dataSourceName string,
dbProperties sqlutil.DbProperties, // nolint:unparam
serverName gomatrixserverlib.ServerName,
) (Database, error) {
uri, err := url.Parse(dataSourceName)
if err != nil {
return nil, fmt.Errorf("Cannot use postgres implementation")
}
switch uri.Scheme {
case "postgres":
return nil, fmt.Errorf("Cannot use postgres implementation")
case "file":
return sqlite3.NewDatabase(dataSourceName, serverName)
default:
return nil, fmt.Errorf("Cannot use postgres implementation")
}
}