mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-27 23:48:27 +00:00
Associate transactions with session IDs instead of device IDs (#789)
This commit is contained in:
parent
5eb63f1d1e
commit
43308d2f3f
9 changed files with 55 additions and 39 deletions
|
@ -21,5 +21,9 @@ type Device struct {
|
||||||
// The access_token granted to this device.
|
// The access_token granted to this device.
|
||||||
// This uniquely identifies the device from all other devices and clients.
|
// This uniquely identifies the device from all other devices and clients.
|
||||||
AccessToken string
|
AccessToken string
|
||||||
|
// The unique ID of the session identified by the access token.
|
||||||
|
// Can be used as a secure substitution in places where data needs to be
|
||||||
|
// associated with access tokens.
|
||||||
|
SessionID int64
|
||||||
// TODO: display name, last used timestamp, keys, etc
|
// TODO: display name, last used timestamp, keys, etc
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,19 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
const devicesSchema = `
|
const devicesSchema = `
|
||||||
|
-- This sequence is used for automatic allocation of session_id.
|
||||||
|
CREATE SEQUENCE IF NOT EXISTS device_session_id_seq START 1;
|
||||||
|
|
||||||
-- Stores data about devices.
|
-- Stores data about devices.
|
||||||
CREATE TABLE IF NOT EXISTS device_devices (
|
CREATE TABLE IF NOT EXISTS device_devices (
|
||||||
-- The access token granted to this device. This has to be the primary key
|
-- The access token granted to this device. This has to be the primary key
|
||||||
-- so we can distinguish which device is making a given request.
|
-- so we can distinguish which device is making a given request.
|
||||||
access_token TEXT NOT NULL PRIMARY KEY,
|
access_token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
-- The auto-allocated unique ID of the session identified by the access token.
|
||||||
|
-- This can be used as a secure substitution of the access token in situations
|
||||||
|
-- where data is associated with access tokens (e.g. transaction storage),
|
||||||
|
-- so we don't have to store users' access tokens everywhere.
|
||||||
|
session_id BIGINT NOT NULL DEFAULT nextval('device_session_id_seq'),
|
||||||
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
-- The device identifier. This only needs to uniquely identify a device for a given user, not globally.
|
||||||
-- access_tokens will be clobbered based on the device ID for a user.
|
-- access_tokens will be clobbered based on the device ID for a user.
|
||||||
device_id TEXT NOT NULL,
|
device_id TEXT NOT NULL,
|
||||||
|
@ -51,10 +59,11 @@ CREATE UNIQUE INDEX IF NOT EXISTS device_localpart_id_idx ON device_devices(loca
|
||||||
`
|
`
|
||||||
|
|
||||||
const insertDeviceSQL = "" +
|
const insertDeviceSQL = "" +
|
||||||
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)"
|
"INSERT INTO device_devices(device_id, localpart, access_token, created_ts, display_name) VALUES ($1, $2, $3, $4, $5)" +
|
||||||
|
" RETURNING session_id"
|
||||||
|
|
||||||
const selectDeviceByTokenSQL = "" +
|
const selectDeviceByTokenSQL = "" +
|
||||||
"SELECT device_id, localpart FROM device_devices WHERE access_token = $1"
|
"SELECT session_id, device_id, localpart FROM device_devices WHERE access_token = $1"
|
||||||
|
|
||||||
const selectDeviceByIDSQL = "" +
|
const selectDeviceByIDSQL = "" +
|
||||||
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
"SELECT display_name FROM device_devices WHERE localpart = $1 and device_id = $2"
|
||||||
|
@ -120,14 +129,16 @@ func (s *devicesStatements) insertDevice(
|
||||||
displayName *string,
|
displayName *string,
|
||||||
) (*authtypes.Device, error) {
|
) (*authtypes.Device, error) {
|
||||||
createdTimeMS := time.Now().UnixNano() / 1000000
|
createdTimeMS := time.Now().UnixNano() / 1000000
|
||||||
|
var sessionID int64
|
||||||
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
stmt := common.TxStmt(txn, s.insertDeviceStmt)
|
||||||
if _, err := stmt.ExecContext(ctx, id, localpart, accessToken, createdTimeMS, displayName); err != nil {
|
if err := stmt.QueryRowContext(ctx, id, localpart, accessToken, createdTimeMS, displayName).Scan(&sessionID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &authtypes.Device{
|
return &authtypes.Device{
|
||||||
ID: id,
|
ID: id,
|
||||||
UserID: userutil.MakeUserID(localpart, s.serverName),
|
UserID: userutil.MakeUserID(localpart, s.serverName),
|
||||||
AccessToken: accessToken,
|
AccessToken: accessToken,
|
||||||
|
SessionID: sessionID,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -161,7 +172,7 @@ func (s *devicesStatements) selectDeviceByToken(
|
||||||
var dev authtypes.Device
|
var dev authtypes.Device
|
||||||
var localpart string
|
var localpart string
|
||||||
stmt := s.selectDeviceByTokenStmt
|
stmt := s.selectDeviceByTokenStmt
|
||||||
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.ID, &localpart)
|
err := stmt.QueryRowContext(ctx, accessToken).Scan(&dev.SessionID, &dev.ID, &localpart)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
dev.UserID = userutil.MakeUserID(localpart, s.serverName)
|
||||||
dev.AccessToken = accessToken
|
dev.AccessToken = accessToken
|
||||||
|
|
|
@ -60,18 +60,18 @@ func SendEvent(
|
||||||
return *resErr
|
return *resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
var txnAndDeviceID *api.TransactionID
|
var txnAndSessionID *api.TransactionID
|
||||||
if txnID != nil {
|
if txnID != nil {
|
||||||
txnAndDeviceID = &api.TransactionID{
|
txnAndSessionID = &api.TransactionID{
|
||||||
TransactionID: *txnID,
|
TransactionID: *txnID,
|
||||||
DeviceID: device.ID,
|
SessionID: device.SessionID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pass the new event to the roomserver and receive the correct event ID
|
// pass the new event to the roomserver and receive the correct event ID
|
||||||
// event ID in case of duplicate transaction is discarded
|
// event ID in case of duplicate transaction is discarded
|
||||||
eventID, err := producer.SendEvents(
|
eventID, err := producer.SendEvents(
|
||||||
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndDeviceID,
|
req.Context(), []gomatrixserverlib.Event{*e}, cfg.Matrix.ServerName, txnAndSessionID,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.LogThenError(req, err)
|
return httputil.LogThenError(req, err)
|
||||||
|
|
|
@ -75,9 +75,9 @@ type InputRoomEvent struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// TransactionID contains the transaction ID sent by a client when sending an
|
// TransactionID contains the transaction ID sent by a client when sending an
|
||||||
// event, along with the ID of that device.
|
// event, along with the ID of the client session.
|
||||||
type TransactionID struct {
|
type TransactionID struct {
|
||||||
DeviceID string `json:"device_id"`
|
SessionID int64 `json:"session_id"`
|
||||||
TransactionID string `json:"id"`
|
TransactionID string `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,7 @@ type RoomEventDatabase interface {
|
||||||
StoreEvent(
|
StoreEvent(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
event gomatrixserverlib.Event,
|
event gomatrixserverlib.Event,
|
||||||
txnAndDeviceID *api.TransactionID,
|
txnAndSessionID *api.TransactionID,
|
||||||
authEventNIDs []types.EventNID,
|
authEventNIDs []types.EventNID,
|
||||||
) (types.RoomNID, types.StateAtEvent, error)
|
) (types.RoomNID, types.StateAtEvent, error)
|
||||||
// Look up the state entries for a list of string event IDs
|
// Look up the state entries for a list of string event IDs
|
||||||
|
@ -67,7 +67,7 @@ type RoomEventDatabase interface {
|
||||||
// Returns an empty string if no such event exists.
|
// Returns an empty string if no such event exists.
|
||||||
GetTransactionEventID(
|
GetTransactionEventID(
|
||||||
ctx context.Context, transactionID string,
|
ctx context.Context, transactionID string,
|
||||||
deviceID string, userID string,
|
sessionID int64, userID string,
|
||||||
) (string, error)
|
) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,7 +100,7 @@ func processRoomEvent(
|
||||||
if input.TransactionID != nil {
|
if input.TransactionID != nil {
|
||||||
tdID := input.TransactionID
|
tdID := input.TransactionID
|
||||||
eventID, err = db.GetTransactionEventID(
|
eventID, err = db.GetTransactionEventID(
|
||||||
ctx, tdID.TransactionID, tdID.DeviceID, input.Event.Sender(),
|
ctx, tdID.TransactionID, tdID.SessionID, input.Event.Sender(),
|
||||||
)
|
)
|
||||||
// On error OR event with the transaction already processed/processesing
|
// On error OR event with the transaction already processed/processesing
|
||||||
if err != nil || eventID != "" {
|
if err != nil || eventID != "" {
|
||||||
|
|
|
@ -47,7 +47,7 @@ func Open(dataSourceName string) (*Database, error) {
|
||||||
// StoreEvent implements input.EventDatabase
|
// StoreEvent implements input.EventDatabase
|
||||||
func (d *Database) StoreEvent(
|
func (d *Database) StoreEvent(
|
||||||
ctx context.Context, event gomatrixserverlib.Event,
|
ctx context.Context, event gomatrixserverlib.Event,
|
||||||
txnAndDeviceID *api.TransactionID, authEventNIDs []types.EventNID,
|
txnAndSessionID *api.TransactionID, authEventNIDs []types.EventNID,
|
||||||
) (types.RoomNID, types.StateAtEvent, error) {
|
) (types.RoomNID, types.StateAtEvent, error) {
|
||||||
var (
|
var (
|
||||||
roomNID types.RoomNID
|
roomNID types.RoomNID
|
||||||
|
@ -58,10 +58,10 @@ func (d *Database) StoreEvent(
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
if txnAndDeviceID != nil {
|
if txnAndSessionID != nil {
|
||||||
if err = d.statements.insertTransaction(
|
if err = d.statements.insertTransaction(
|
||||||
ctx, txnAndDeviceID.TransactionID,
|
ctx, txnAndSessionID.TransactionID,
|
||||||
txnAndDeviceID.DeviceID, event.Sender(), event.EventID(),
|
txnAndSessionID.SessionID, event.Sender(), event.EventID(),
|
||||||
); err != nil {
|
); err != nil {
|
||||||
return 0, types.StateAtEvent{}, err
|
return 0, types.StateAtEvent{}, err
|
||||||
}
|
}
|
||||||
|
@ -322,9 +322,9 @@ func (d *Database) GetLatestEventsForUpdate(
|
||||||
// GetTransactionEventID implements input.EventDatabase
|
// GetTransactionEventID implements input.EventDatabase
|
||||||
func (d *Database) GetTransactionEventID(
|
func (d *Database) GetTransactionEventID(
|
||||||
ctx context.Context, transactionID string,
|
ctx context.Context, transactionID string,
|
||||||
deviceID string, userID string,
|
sessionID int64, userID string,
|
||||||
) (string, error) {
|
) (string, error) {
|
||||||
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, deviceID, userID)
|
eventID, err := d.statements.selectTransactionEventID(ctx, transactionID, sessionID, userID)
|
||||||
if err == sql.ErrNoRows {
|
if err == sql.ErrNoRows {
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -23,8 +23,8 @@ const transactionsSchema = `
|
||||||
CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||||
-- The transaction ID of the event.
|
-- The transaction ID of the event.
|
||||||
transaction_id TEXT NOT NULL,
|
transaction_id TEXT NOT NULL,
|
||||||
-- The device ID of the originating transaction.
|
-- The session ID of the originating transaction.
|
||||||
device_id TEXT NOT NULL,
|
session_id BIGINT NOT NULL,
|
||||||
-- User ID of the sender who authored the event
|
-- User ID of the sender who authored the event
|
||||||
user_id TEXT NOT NULL,
|
user_id TEXT NOT NULL,
|
||||||
-- Event ID corresponding to the transaction
|
-- Event ID corresponding to the transaction
|
||||||
|
@ -32,16 +32,16 @@ CREATE TABLE IF NOT EXISTS roomserver_transactions (
|
||||||
event_id TEXT NOT NULL,
|
event_id TEXT NOT NULL,
|
||||||
-- A transaction ID is unique for a user and device
|
-- A transaction ID is unique for a user and device
|
||||||
-- This automatically creates an index.
|
-- This automatically creates an index.
|
||||||
PRIMARY KEY (transaction_id, device_id, user_id)
|
PRIMARY KEY (transaction_id, session_id, user_id)
|
||||||
);
|
);
|
||||||
`
|
`
|
||||||
const insertTransactionSQL = "" +
|
const insertTransactionSQL = "" +
|
||||||
"INSERT INTO roomserver_transactions (transaction_id, device_id, user_id, event_id)" +
|
"INSERT INTO roomserver_transactions (transaction_id, session_id, user_id, event_id)" +
|
||||||
" VALUES ($1, $2, $3, $4)"
|
" VALUES ($1, $2, $3, $4)"
|
||||||
|
|
||||||
const selectTransactionEventIDSQL = "" +
|
const selectTransactionEventIDSQL = "" +
|
||||||
"SELECT event_id FROM roomserver_transactions" +
|
"SELECT event_id FROM roomserver_transactions" +
|
||||||
" WHERE transaction_id = $1 AND device_id = $2 AND user_id = $3"
|
" WHERE transaction_id = $1 AND session_id = $2 AND user_id = $3"
|
||||||
|
|
||||||
type transactionStatements struct {
|
type transactionStatements struct {
|
||||||
insertTransactionStmt *sql.Stmt
|
insertTransactionStmt *sql.Stmt
|
||||||
|
@ -63,12 +63,12 @@ func (s *transactionStatements) prepare(db *sql.DB) (err error) {
|
||||||
func (s *transactionStatements) insertTransaction(
|
func (s *transactionStatements) insertTransaction(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
transactionID string,
|
transactionID string,
|
||||||
deviceID string,
|
sessionID int64,
|
||||||
userID string,
|
userID string,
|
||||||
eventID string,
|
eventID string,
|
||||||
) (err error) {
|
) (err error) {
|
||||||
_, err = s.insertTransactionStmt.ExecContext(
|
_, err = s.insertTransactionStmt.ExecContext(
|
||||||
ctx, transactionID, deviceID, userID, eventID,
|
ctx, transactionID, sessionID, userID, eventID,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -76,11 +76,11 @@ func (s *transactionStatements) insertTransaction(
|
||||||
func (s *transactionStatements) selectTransactionEventID(
|
func (s *transactionStatements) selectTransactionEventID(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
transactionID string,
|
transactionID string,
|
||||||
deviceID string,
|
sessionID int64,
|
||||||
userID string,
|
userID string,
|
||||||
) (eventID string, err error) {
|
) (eventID string, err error) {
|
||||||
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
err = s.selectTransactionEventIDStmt.QueryRowContext(
|
||||||
ctx, transactionID, deviceID, userID,
|
ctx, transactionID, sessionID, userID,
|
||||||
).Scan(&eventID)
|
).Scan(&eventID)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,7 +54,7 @@ CREATE TABLE IF NOT EXISTS syncapi_output_room_events (
|
||||||
-- if there is no delta.
|
-- if there is no delta.
|
||||||
add_state_ids TEXT[],
|
add_state_ids TEXT[],
|
||||||
remove_state_ids TEXT[],
|
remove_state_ids TEXT[],
|
||||||
device_id TEXT, -- The local device that sent the event, if any
|
session_id BIGINT, -- The client session that sent the event, if any
|
||||||
transaction_id TEXT -- The transaction id used to send the event, if any
|
transaction_id TEXT -- The transaction id used to send the event, if any
|
||||||
);
|
);
|
||||||
-- for event selection
|
-- for event selection
|
||||||
|
@ -63,14 +63,14 @@ CREATE UNIQUE INDEX IF NOT EXISTS syncapi_event_id_idx ON syncapi_output_room_ev
|
||||||
|
|
||||||
const insertEventSQL = "" +
|
const insertEventSQL = "" +
|
||||||
"INSERT INTO syncapi_output_room_events (" +
|
"INSERT INTO syncapi_output_room_events (" +
|
||||||
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, device_id, transaction_id" +
|
"room_id, event_id, event_json, type, sender, contains_url, add_state_ids, remove_state_ids, session_id, transaction_id" +
|
||||||
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
|
") VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id"
|
||||||
|
|
||||||
const selectEventsSQL = "" +
|
const selectEventsSQL = "" +
|
||||||
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
"SELECT id, event_json FROM syncapi_output_room_events WHERE event_id = ANY($1)"
|
||||||
|
|
||||||
const selectRecentEventsSQL = "" +
|
const selectRecentEventsSQL = "" +
|
||||||
"SELECT id, event_json, device_id, transaction_id FROM syncapi_output_room_events" +
|
"SELECT id, event_json, session_id, transaction_id FROM syncapi_output_room_events" +
|
||||||
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
" WHERE room_id = $1 AND id > $2 AND id <= $3" +
|
||||||
" ORDER BY id DESC LIMIT $4"
|
" ORDER BY id DESC LIMIT $4"
|
||||||
|
|
||||||
|
@ -221,9 +221,10 @@ func (s *outputRoomEventsStatements) insertEvent(
|
||||||
event *gomatrixserverlib.Event, addState, removeState []string,
|
event *gomatrixserverlib.Event, addState, removeState []string,
|
||||||
transactionID *api.TransactionID,
|
transactionID *api.TransactionID,
|
||||||
) (streamPos int64, err error) {
|
) (streamPos int64, err error) {
|
||||||
var deviceID, txnID *string
|
var txnID *string
|
||||||
|
var sessionID *int64
|
||||||
if transactionID != nil {
|
if transactionID != nil {
|
||||||
deviceID = &transactionID.DeviceID
|
sessionID = &transactionID.SessionID
|
||||||
txnID = &transactionID.TransactionID
|
txnID = &transactionID.TransactionID
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,7 +247,7 @@ func (s *outputRoomEventsStatements) insertEvent(
|
||||||
containsURL,
|
containsURL,
|
||||||
pq.StringArray(addState),
|
pq.StringArray(addState),
|
||||||
pq.StringArray(removeState),
|
pq.StringArray(removeState),
|
||||||
deviceID,
|
sessionID,
|
||||||
txnID,
|
txnID,
|
||||||
).Scan(&streamPos)
|
).Scan(&streamPos)
|
||||||
return
|
return
|
||||||
|
@ -296,11 +297,11 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
|
||||||
var (
|
var (
|
||||||
streamPos int64
|
streamPos int64
|
||||||
eventBytes []byte
|
eventBytes []byte
|
||||||
deviceID *string
|
sessionID *int64
|
||||||
txnID *string
|
txnID *string
|
||||||
transactionID *api.TransactionID
|
transactionID *api.TransactionID
|
||||||
)
|
)
|
||||||
if err := rows.Scan(&streamPos, &eventBytes, &deviceID, &txnID); err != nil {
|
if err := rows.Scan(&streamPos, &eventBytes, &sessionID, &txnID); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// TODO: Handle redacted events
|
// TODO: Handle redacted events
|
||||||
|
@ -309,9 +310,9 @@ func rowsToStreamEvents(rows *sql.Rows) ([]streamEvent, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if deviceID != nil && txnID != nil {
|
if sessionID != nil && txnID != nil {
|
||||||
transactionID = &api.TransactionID{
|
transactionID = &api.TransactionID{
|
||||||
DeviceID: *deviceID,
|
SessionID: *sessionID,
|
||||||
TransactionID: *txnID,
|
TransactionID: *txnID,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -893,7 +893,7 @@ func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrix
|
||||||
for i := 0; i < len(in); i++ {
|
for i := 0; i < len(in); i++ {
|
||||||
out[i] = in[i].Event
|
out[i] = in[i].Event
|
||||||
if device != nil && in[i].transactionID != nil {
|
if device != nil && in[i].transactionID != nil {
|
||||||
if device.UserID == in[i].Sender() && device.ID == in[i].transactionID.DeviceID {
|
if device.UserID == in[i].Sender() && device.SessionID == in[i].transactionID.SessionID {
|
||||||
err := out[i].SetUnsignedField(
|
err := out[i].SetUnsignedField(
|
||||||
"transaction_id", in[i].transactionID.TransactionID,
|
"transaction_id", in[i].transactionID.TransactionID,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue