Fix SQLite session_id (#2977)

This fixes an issue with device_id/session_ids.
If a `device_id` is reused, we would reuse the same `session_id`, since
we delete one device and insert a new one directly, resulting in the
query to get a new `session_id` to return the previous session_id.
(`SELECT count(access_token)`)
This commit is contained in:
Till 2023-02-17 11:39:46 +01:00 committed by GitHub
parent 11d9b9db0e
commit f0805071d5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 74 additions and 18 deletions

View file

@ -588,16 +588,42 @@ func (d *Database) CreateDevice(
deviceID *string, accessToken string, displayName *string, ipAddr, userAgent string,
) (dev *api.Device, returnErr error) {
if deviceID != nil {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
_, ok := d.Writer.(*sqlutil.ExclusiveWriter)
if ok { // we're using most likely using SQLite, so do things a little different
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
devices, err := d.Devices.SelectDevicesByLocalpart(ctx, txn, localpart, serverName, "")
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
// No devices yet, only create a new one
if len(devices) == 0 {
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
}
sessionID := devices[0].SessionID + 1
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
// Create a new device with the session ID incremented
dev, err = d.Devices.InsertDeviceWithSessionID(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent, sessionID)
return err
})
} else {
returnErr = d.Writer.Do(d.DB, nil, func(txn *sql.Tx) error {
var err error
// Revoke existing tokens for this device
if err = d.Devices.DeleteDevice(ctx, txn, *deviceID, localpart, serverName); err != nil {
return err
}
dev, err = d.Devices.InsertDevice(ctx, txn, *deviceID, localpart, serverName, accessToken, displayName, ipAddr, userAgent)
return err
})
}
} else {
// We generate device IDs in a loop in case its already taken.
// We cap this at going round 5 times to ensure we don't spin forever
@ -618,7 +644,7 @@ func (d *Database) CreateDevice(
}
}
}
return
return dev, returnErr
}
// generateDeviceID creates a new device id. Returns an error if failed to generate