Refactor user API storage (#2202)

* Refactor User API database

* Fix migration bugs
This commit is contained in:
Neil Alexander 2022-02-18 13:51:59 +00:00 committed by GitHub
parent 9bd5e414c9
commit 9f4a39e8e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 1165 additions and 1671 deletions

View file

@ -23,6 +23,7 @@ import (
"github.com/matrix-org/dendrite/internal"
"github.com/matrix-org/dendrite/internal/sqlutil"
"github.com/matrix-org/dendrite/userapi/api"
"github.com/matrix-org/dendrite/userapi/storage/tables"
"github.com/matrix-org/dendrite/clientapi/userutil"
"github.com/matrix-org/gomatrixserverlib"
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
type devicesStatements struct {
db *sql.DB
writer sqlutil.Writer
insertDeviceStmt *sql.Stmt
selectDevicesCountStmt *sql.Stmt
selectDeviceByTokenStmt *sql.Stmt
@ -98,55 +98,33 @@ type devicesStatements struct {
serverName gomatrixserverlib.ServerName
}
func (s *devicesStatements) execSchema(db *sql.DB) error {
func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
s := &devicesStatements{
db: db,
serverName: serverName,
}
_, err := db.Exec(devicesSchema)
return err
}
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
s.db = db
s.writer = writer
if err = s.execSchema(db); err != nil {
return
if err != nil {
return nil, err
}
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
return
}
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
return
}
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
return
}
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
return
}
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
return
}
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
return
}
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
return
}
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
return
}
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
return
}
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
return
}
s.serverName = server
return
return s, sqlutil.StatementList{
{&s.insertDeviceStmt, insertDeviceSQL},
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
{&s.deleteDeviceStmt, deleteDeviceSQL},
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
}.Prepare(db)
}
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
// Returns an error if the user already has a device with the given device ID.
// Returns the device on success.
func (s *devicesStatements) insertDevice(
func (s *devicesStatements) InsertDevice(
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
displayName *string, ipAddr, userAgent string,
) (*api.Device, error) {
@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice(
}, nil
}
func (s *devicesStatements) deleteDevice(
func (s *devicesStatements) DeleteDevice(
ctx context.Context, txn *sql.Tx, id, localpart string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
return err
}
func (s *devicesStatements) deleteDevices(
func (s *devicesStatements) DeleteDevices(
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
) error {
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
return err
}
func (s *devicesStatements) deleteDevicesByLocalpart(
func (s *devicesStatements) DeleteDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) error {
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
return err
}
func (s *devicesStatements) updateDeviceName(
func (s *devicesStatements) UpdateDeviceName(
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
) error {
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
return err
}
func (s *devicesStatements) selectDeviceByToken(
func (s *devicesStatements) SelectDeviceByToken(
ctx context.Context, accessToken string,
) (*api.Device, error) {
var dev api.Device
@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
// selectDeviceByID retrieves a device from the database with the given user
// localpart and deviceID
func (s *devicesStatements) selectDeviceByID(
func (s *devicesStatements) SelectDeviceByID(
ctx context.Context, localpart, deviceID string,
) (*api.Device, error) {
var dev api.Device
@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
return &dev, err
}
func (s *devicesStatements) selectDevicesByLocalpart(
func (s *devicesStatements) SelectDevicesByLocalpart(
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
) ([]api.Device, error) {
devices := []api.Device{}
@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
return devices, nil
}
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
iDeviceIDs := make([]interface{}, len(deviceIDs))
for i := range deviceIDs {
@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
return devices, rows.Err()
}
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
lastSeenTs := time.Now().UnixNano() / 1000000
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)