Associate transactions with session IDs instead of device IDs (#789)

This commit is contained in:
Alex Chen 2019-08-24 00:55:40 +08:00 committed by GitHub
parent 5eb63f1d1e
commit 43308d2f3f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 55 additions and 39 deletions

View file

@ -21,5 +21,9 @@ type Device struct {
// The access_token granted to this device.
// This uniquely identifies the device from all other devices and clients.
AccessToken string
// The unique ID of the session identified by the access token.
// Can be used as a secure substitution in places where data needs to be
// associated with access tokens.
SessionID int64
// TODO: display name, last used timestamp, keys, etc
}

View file

@ -27,11 +27,19 @@ import (
)
const devicesSchema = `
-- This sequence is used for automatic allocation of session_id.
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
-- Stores data about devices.
CREATE TABLE IF NOT EXISTS device_devices (
-- The access token granted to this device. This has to be the primary key
-- so we can distinguish which device is making a given request.
access_token TEXT NOT NULL PRIMARY KEY,
-- The auto-allocated unique ID of the session identified by the access token.
-- This can be used as a secure substitution of the access token in situations
-- where data is associated with access tokens (e.g. transaction storage),
-- so we don't have to store users' access tokens everywhere.
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
-- access_tokens will be clobbered based on the device ID for a user.
device_id TEXT NOT NULL,
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
`
const insertDeviceSQL = "" +
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
" RETURNING session_id"
const selectDeviceByTokenSQL = "" +
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
const selectDeviceByIDSQL = "" +
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
displayName *string,
) (*authtypes.Device, error) {
createdTimeMS := time.Now().UnixNano() / 1000000
var sessionID int64
stmt := common.TxStmt(txn, s.insertDeviceStmt)
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
return nil, err
}
return &authtypes.Device{
ID: id,
UserID: userutil.MakeUserID(localpart, s.serverName),
AccessToken: accessToken,
SessionID: sessionID,
}, nil
}
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
var dev authtypes.Device
var localpart string
stmt := s.selectDeviceByTokenStmt
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
if err == nil {
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
dev.AccessToken = accessToken

View file

@ -60,18 +60,18 @@ func SendEvent(
return *resErr
}
var txnAndDeviceID *api.TransactionID
var txnAndSessionID *api.TransactionID
if txnID != nil {
txnAndDeviceID = &api.TransactionID{
txnAndSessionID = &api.TransactionID{
TransactionID: *txnID,
DeviceID: device.ID,
SessionID: device.SessionID,
}
}
// pass the new event to the roomserver and receive the correct event ID
// event ID in case of duplicate transaction is discarded
eventID, err := producer.SendEvents(
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID,
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
)
if err != nil {
return httputil.LogThenError(req, err)