From 0cc1c2f90972c84d5ee655577923528fafe7d8f9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 30 Nov 2017 14:51:08 +0000 Subject: [PATCH] Add transaction ID to events if sending device --- .../dendrite/syncapi/storage/syncserver.go | 38 +++++++++++++------ .../dendrite/syncapi/sync/notifier.go | 2 +- .../dendrite/syncapi/sync/request.go | 8 ++-- .../dendrite/syncapi/sync/requestpool.go | 8 ++-- 4 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go index 8a5b9648..58884b86 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go +++ b/src/github.com/matrix-org/dendrite/syncapi/storage/syncserver.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" "github.com/matrix-org/dendrite/roomserver/api" // Import the postgres database driver. _ "github.com/lib/pq" @@ -92,7 +93,7 @@ func (d *SyncServerDatabase) Events(ctx context.Context, eventIDs []string) ([]g if err != nil { return nil, err } - return streamEventsToEvents(streamEvents), nil + return streamEventsToEvents(nil, streamEvents), nil } // WriteEvent into the database. It is not safe to call this function from multiple goroutines, as it would create races @@ -211,7 +212,7 @@ func (d *SyncServerDatabase) syncStreamPositionTx( // IncrementalSync returns all the data needed in order to create an incremental sync response. func (d *SyncServerDatabase) IncrementalSync( ctx context.Context, - userID string, + device *authtypes.Device, fromPos, toPos types.StreamPosition, numRecentEventsPerRoom int, ) (*types.Response, error) { @@ -226,21 +227,21 @@ func (d *SyncServerDatabase) IncrementalSync( // joined rooms, but also which rooms have membership transitions for this user between the 2 stream positions. // This works out what the 'state' key should be for each room as well as which membership block // to put the room into. - deltas, err := d.getStateDeltas(ctx, txn, fromPos, toPos, userID) + deltas, err := d.getStateDeltas(ctx, device, txn, fromPos, toPos, device.UserID) if err != nil { return nil, err } res := types.NewResponse(toPos) for _, delta := range deltas { - err = d.addRoomDeltaToResponse(ctx, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) + err = d.addRoomDeltaToResponse(ctx, device, txn, fromPos, toPos, delta, numRecentEventsPerRoom, res) if err != nil { return nil, err } } // TODO: This should be done in getStateDeltas - if err = d.addInvitesToResponse(ctx, txn, userID, fromPos, toPos, res); err != nil { + if err = d.addInvitesToResponse(ctx, txn, device.UserID, fromPos, toPos, res); err != nil { return nil, err } @@ -292,7 +293,7 @@ func (d *SyncServerDatabase) CompleteSync( if err != nil { return nil, err } - recentEvents := streamEventsToEvents(recentStreamEvents) + recentEvents := streamEventsToEvents(nil, recentStreamEvents) stateEvents = removeDuplicates(stateEvents, recentEvents) jr := types.NewJoinResponse() @@ -390,7 +391,9 @@ func (d *SyncServerDatabase) addInvitesToResponse( // addRoomDeltaToResponse adds a room state delta to a sync response func (d *SyncServerDatabase) addRoomDeltaToResponse( - ctx context.Context, txn *sql.Tx, + ctx context.Context, + device *authtypes.Device, + txn *sql.Tx, fromPos, toPos types.StreamPosition, delta stateDelta, numRecentEventsPerRoom int, @@ -412,7 +415,7 @@ func (d *SyncServerDatabase) addRoomDeltaToResponse( if err != nil { return err } - recentEvents := streamEventsToEvents(recentStreamEvents) + recentEvents := streamEventsToEvents(device, recentStreamEvents) delta.stateEvents = removeDuplicates(delta.stateEvents, recentEvents) // roll back // Don't bother appending empty room entries @@ -529,7 +532,7 @@ func (d *SyncServerDatabase) fetchMissingStateEvents( } func (d *SyncServerDatabase) getStateDeltas( - ctx context.Context, txn *sql.Tx, + ctx context.Context, device *authtypes.Device, txn *sql.Tx, fromPos, toPos types.StreamPosition, userID string, ) ([]stateDelta, error) { // Implement membership change algorithm: https://github.com/matrix-org/synapse/blob/v0.19.3/synapse/handlers/sync.py#L821 @@ -578,7 +581,7 @@ func (d *SyncServerDatabase) getStateDeltas( deltas = append(deltas, stateDelta{ membership: membership, membershipPos: ev.streamPosition, - stateEvents: streamEventsToEvents(stateStreamEvents), + stateEvents: streamEventsToEvents(device, stateStreamEvents), roomID: roomID, }) break @@ -594,7 +597,7 @@ func (d *SyncServerDatabase) getStateDeltas( for _, joinedRoomID := range joinedRoomIDs { deltas = append(deltas, stateDelta{ membership: "join", - stateEvents: streamEventsToEvents(state[joinedRoomID]), + stateEvents: streamEventsToEvents(device, state[joinedRoomID]), roomID: joinedRoomID, }) } @@ -602,10 +605,21 @@ func (d *SyncServerDatabase) getStateDeltas( return deltas, nil } -func streamEventsToEvents(in []streamEvent) []gomatrixserverlib.Event { +func streamEventsToEvents(device *authtypes.Device, in []streamEvent) []gomatrixserverlib.Event { out := make([]gomatrixserverlib.Event, len(in)) for i := 0; i < len(in); i++ { out[i] = in[i].Event + if device != nil && in[i].transactionID != nil { + if device.ID == in[i].transactionID.DeviceID { + // TODO: Don't clobber unsigned + ev, err := out[i].SetUnsigned(map[string]string{ + "transaction_id": in[i].transactionID.TransactionID, + }) + if err == nil { + out[i] = ev + } + } + } } return out } diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go index 4712a2c7..5ed701d8 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/notifier.go @@ -123,7 +123,7 @@ func (n *Notifier) GetListener(req syncRequest) UserStreamListener { n.removeEmptyUserStreams() - return n.fetchUserStream(req.userID, true).GetListener(req.ctx) + return n.fetchUserStream(req.device.UserID, true).GetListener(req.ctx) } // Load the membership states required to notify users correctly. diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go index 7f525981..3c1befdd 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/request.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/request.go @@ -20,6 +20,8 @@ import ( "strconv" "time" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/syncapi/types" "github.com/matrix-org/util" log "github.com/sirupsen/logrus" @@ -31,7 +33,7 @@ const defaultTimelineLimit = 20 // syncRequest represents a /sync request, with sensible defaults/sanity checks applied. type syncRequest struct { ctx context.Context - userID string + device authtypes.Device limit int timeout time.Duration since *types.StreamPosition // nil means that no since token was supplied @@ -39,7 +41,7 @@ type syncRequest struct { log *log.Entry } -func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { +func newSyncRequest(req *http.Request, device authtypes.Device) (*syncRequest, error) { timeout := getTimeout(req.URL.Query().Get("timeout")) fullState := req.URL.Query().Get("full_state") wantFullState := fullState != "" && fullState != "false" @@ -50,7 +52,7 @@ func newSyncRequest(req *http.Request, userID string) (*syncRequest, error) { // TODO: Additional query params: set_presence, filter return &syncRequest{ ctx: req.Context(), - userID: userID, + device: device, timeout: timeout, since: since, wantFullState: wantFullState, diff --git a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go index 15993b77..e9600243 100644 --- a/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go +++ b/src/github.com/matrix-org/dendrite/syncapi/sync/requestpool.go @@ -48,7 +48,7 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype // Extract values from request logger := util.GetLogger(req.Context()) userID := device.UserID - syncReq, err := newSyncRequest(req, userID) + syncReq, err := newSyncRequest(req, *device) if err != nil { return util.JSONResponse{ Code: 400, @@ -122,16 +122,16 @@ func (rp *RequestPool) OnIncomingSyncRequest(req *http.Request, device *authtype func (rp *RequestPool) currentSyncForUser(req syncRequest, currentPos types.StreamPosition) (res *types.Response, err error) { // TODO: handle ignored users if req.since == nil { - res, err = rp.db.CompleteSync(req.ctx, req.userID, req.limit) + res, err = rp.db.CompleteSync(req.ctx, req.device.UserID, req.limit) } else { - res, err = rp.db.IncrementalSync(req.ctx, req.userID, *req.since, currentPos, req.limit) + res, err = rp.db.IncrementalSync(req.ctx, &req.device, *req.since, currentPos, req.limit) } if err != nil { return } - res, err = rp.appendAccountData(res, req.userID, req, currentPos) + res, err = rp.appendAccountData(res, req.device.UserID, req, currentPos) return }