mirror of
https://github.com/hoernschen/dendrite.git
synced 2025-08-02 06:12:45 +00:00
Refactor user API storage (#2202)
* Refactor User API database * Fix migration bugs
This commit is contained in:
parent
9bd5e414c9
commit
9f4a39e8e0
22 changed files with 1165 additions and 1671 deletions
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/matrix-org/dendrite/internal"
|
||||
"github.com/matrix-org/dendrite/internal/sqlutil"
|
||||
"github.com/matrix-org/dendrite/userapi/api"
|
||||
"github.com/matrix-org/dendrite/userapi/storage/tables"
|
||||
|
||||
"github.com/matrix-org/dendrite/clientapi/userutil"
|
||||
"github.com/matrix-org/gomatrixserverlib"
|
||||
|
@ -84,7 +85,6 @@ const updateDeviceLastSeen = "" +
|
|||
|
||||
type devicesStatements struct {
|
||||
db *sql.DB
|
||||
writer sqlutil.Writer
|
||||
insertDeviceStmt *sql.Stmt
|
||||
selectDevicesCountStmt *sql.Stmt
|
||||
selectDeviceByTokenStmt *sql.Stmt
|
||||
|
@ -98,55 +98,33 @@ type devicesStatements struct {
|
|||
serverName gomatrixserverlib.ServerName
|
||||
}
|
||||
|
||||
func (s *devicesStatements) execSchema(db *sql.DB) error {
|
||||
func NewSQLiteDevicesTable(db *sql.DB, serverName gomatrixserverlib.ServerName) (tables.DevicesTable, error) {
|
||||
s := &devicesStatements{
|
||||
db: db,
|
||||
serverName: serverName,
|
||||
}
|
||||
_, err := db.Exec(devicesSchema)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) prepare(db *sql.DB, writer sqlutil.Writer, server gomatrixserverlib.ServerName) (err error) {
|
||||
s.db = db
|
||||
s.writer = writer
|
||||
if err = s.execSchema(db); err != nil {
|
||||
return
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.insertDeviceStmt, err = db.Prepare(insertDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesCountStmt, err = db.Prepare(selectDevicesCountSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByTokenStmt, err = db.Prepare(selectDeviceByTokenSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDeviceByIDStmt, err = db.Prepare(selectDeviceByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByLocalpartStmt, err = db.Prepare(selectDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceNameStmt, err = db.Prepare(updateDeviceNameSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDeviceStmt, err = db.Prepare(deleteDeviceSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.deleteDevicesByLocalpartStmt, err = db.Prepare(deleteDevicesByLocalpartSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.selectDevicesByIDStmt, err = db.Prepare(selectDevicesByIDSQL); err != nil {
|
||||
return
|
||||
}
|
||||
if s.updateDeviceLastSeenStmt, err = db.Prepare(updateDeviceLastSeen); err != nil {
|
||||
return
|
||||
}
|
||||
s.serverName = server
|
||||
return
|
||||
return s, sqlutil.StatementList{
|
||||
{&s.insertDeviceStmt, insertDeviceSQL},
|
||||
{&s.selectDevicesCountStmt, selectDevicesCountSQL},
|
||||
{&s.selectDeviceByTokenStmt, selectDeviceByTokenSQL},
|
||||
{&s.selectDeviceByIDStmt, selectDeviceByIDSQL},
|
||||
{&s.selectDevicesByLocalpartStmt, selectDevicesByLocalpartSQL},
|
||||
{&s.updateDeviceNameStmt, updateDeviceNameSQL},
|
||||
{&s.deleteDeviceStmt, deleteDeviceSQL},
|
||||
{&s.deleteDevicesByLocalpartStmt, deleteDevicesByLocalpartSQL},
|
||||
{&s.selectDevicesByIDStmt, selectDevicesByIDSQL},
|
||||
{&s.updateDeviceLastSeenStmt, updateDeviceLastSeen},
|
||||
}.Prepare(db)
|
||||
}
|
||||
|
||||
// insertDevice creates a new device. Returns an error if any device with the same access token already exists.
|
||||
// Returns an error if the user already has a device with the given device ID.
|
||||
// Returns the device on success.
|
||||
func (s *devicesStatements) insertDevice(
|
||||
func (s *devicesStatements) InsertDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart, accessToken string,
|
||||
displayName *string, ipAddr, userAgent string,
|
||||
) (*api.Device, error) {
|
||||
|
@ -172,7 +150,7 @@ func (s *devicesStatements) insertDevice(
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevice(
|
||||
func (s *devicesStatements) DeleteDevice(
|
||||
ctx context.Context, txn *sql.Tx, id, localpart string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDeviceStmt)
|
||||
|
@ -180,7 +158,7 @@ func (s *devicesStatements) deleteDevice(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevices(
|
||||
func (s *devicesStatements) DeleteDevices(
|
||||
ctx context.Context, txn *sql.Tx, localpart string, devices []string,
|
||||
) error {
|
||||
orig := strings.Replace(deleteDevicesSQL, "($2)", sqlutil.QueryVariadicOffset(len(devices), 1), 1)
|
||||
|
@ -198,7 +176,7 @@ func (s *devicesStatements) deleteDevices(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) deleteDevicesByLocalpart(
|
||||
func (s *devicesStatements) DeleteDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.deleteDevicesByLocalpartStmt)
|
||||
|
@ -206,7 +184,7 @@ func (s *devicesStatements) deleteDevicesByLocalpart(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceName(
|
||||
func (s *devicesStatements) UpdateDeviceName(
|
||||
ctx context.Context, txn *sql.Tx, localpart, deviceID string, displayName *string,
|
||||
) error {
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceNameStmt)
|
||||
|
@ -214,7 +192,7 @@ func (s *devicesStatements) updateDeviceName(
|
|||
return err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDeviceByToken(
|
||||
func (s *devicesStatements) SelectDeviceByToken(
|
||||
ctx context.Context, accessToken string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
@ -230,7 +208,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
|||
|
||||
// selectDeviceByID retrieves a device from the database with the given user
|
||||
// localpart and deviceID
|
||||
func (s *devicesStatements) selectDeviceByID(
|
||||
func (s *devicesStatements) SelectDeviceByID(
|
||||
ctx context.Context, localpart, deviceID string,
|
||||
) (*api.Device, error) {
|
||||
var dev api.Device
|
||||
|
@ -247,7 +225,7 @@ func (s *devicesStatements) selectDeviceByID(
|
|||
return &dev, err
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByLocalpart(
|
||||
func (s *devicesStatements) SelectDevicesByLocalpart(
|
||||
ctx context.Context, txn *sql.Tx, localpart, exceptDeviceID string,
|
||||
) ([]api.Device, error) {
|
||||
devices := []api.Device{}
|
||||
|
@ -288,7 +266,7 @@ func (s *devicesStatements) selectDevicesByLocalpart(
|
|||
return devices, nil
|
||||
}
|
||||
|
||||
func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
func (s *devicesStatements) SelectDevicesByID(ctx context.Context, deviceIDs []string) ([]api.Device, error) {
|
||||
sqlQuery := strings.Replace(selectDevicesByIDSQL, "($1)", sqlutil.QueryVariadic(len(deviceIDs)), 1)
|
||||
iDeviceIDs := make([]interface{}, len(deviceIDs))
|
||||
for i := range deviceIDs {
|
||||
|
@ -317,7 +295,7 @@ func (s *devicesStatements) selectDevicesByID(ctx context.Context, deviceIDs []s
|
|||
return devices, rows.Err()
|
||||
}
|
||||
|
||||
func (s *devicesStatements) updateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
func (s *devicesStatements) UpdateDeviceLastSeen(ctx context.Context, txn *sql.Tx, localpart, deviceID, ipAddr string) error {
|
||||
lastSeenTs := time.Now().UnixNano() / 1000000
|
||||
stmt := sqlutil.TxStmt(txn, s.updateDeviceLastSeenStmt)
|
||||
_, err := stmt.ExecContext(ctx, lastSeenTs, ipAddr, localpart, deviceID)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue