mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 14:12:47 +00:00
Refactor user API storage (#2202)
* Refactor User API database * Fix migration bugs
This commit is contained in:
parent
9bd5e414c9
commit
9f4a39e8e0
22 changed files with 1165 additions and 1671 deletions
|
@ -24,6 +24,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
)
|
||||
|
||||
|
@ -111,53 +112,32 @@ type devicesStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||
func NewPostgresDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||
s := &devicesStatements{
|
||||
serverName: serverName,
|
||||
}
|
||||
_, err := db.Exec(devicesSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, server gomatrixserverlib.ServerName) (err error) {
|
||||
if err = s.execSchema(db); err != nil {
|
||||
return
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesStmt, err = db.Prepare(deleteDevicesSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||
{&s.deleteDevicesStmt, deleteDevicesSQL},
|
||||
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
|
@ -179,7 +159,7 @@ func (s *devicesStatements) insertDevice(
|
|||
}
|
||||
|
||||
// deleteDevice removes a single device by id and user localpart.
|
||||
func (s *devicesStatements) deleteDevice(
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
|
@ -189,7 +169,7 @@ func (s *devicesStatements) deleteDevice(
|
|||
|
||||
// deleteDevices removes a single or multiple devices by ids and user localpart.
|
||||
// Returns an error if the execution failed.
|
||||
func (s *devicesStatements) deleteDevices(
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesStmt)
|
||||
|
@ -199,7 +179,7 @@ func (s *devicesStatements) deleteDevices(
|
|||
|
||||
// deleteDevicesByLocalpart removes all devices for the
|
||||
// given user localpart.
|
||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
|
@ -207,7 +187,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
|
@ -215,7 +195,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
func (s *devicesStatements) SelectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
@ -231,7 +211,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
@ -248,7 +228,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
return &dev, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
rows, err := s.selectDevicesByIDStmt.QueryContext(ctx, pq.StringArray(deviceIDs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -271,7 +251,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
@ -313,7 +293,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue