diff --git a/syncapi/storage/shared/syncserver.go b/syncapi/storage/shared/syncserver.go index 90db145e..d8ae455e 100644 --- a/syncapi/storage/shared/syncserver.go +++ b/syncapi/storage/shared/syncserver.go @@ -1032,7 +1032,9 @@ func (d *Database) getStateDeltas( } // add peek blocks + peeking := make(map[string]bool) for _, peek := range peeks { + peeking[peek.RoomID] = true if peek.New { // send full room state down instead of a delta var s []types.StreamEvent @@ -1067,6 +1069,14 @@ func (d *Database) getStateDeltas( // the timeline. if membership := getMembershipFromEvent(&ev.Event, userID); membership != "" { if membership == gomatrixserverlib.Join { + if peeking[roomID] { + // we automatically cancel our peeks when we join a room + _, err = d.Peeks.DeletePeeks(ctx, txn, roomID, userID) + if err != nil { + return nil, nil, err + } + } + // send full room state down instead of a delta var s []types.StreamEvent s, err = d.currentStateStreamEventsForRoom(ctx, txn, roomID, stateFilter) diff --git a/syncapi/storage/sqlite3/peeks_table.go b/syncapi/storage/sqlite3/peeks_table.go index 088f4007..7e8d4c69 100644 --- a/syncapi/storage/sqlite3/peeks_table.go +++ b/syncapi/storage/sqlite3/peeks_table.go @@ -48,6 +48,9 @@ const insertPeekSQL = "" + const deletePeekSQL = "" + "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2 and device_id = $3" +const deletePeeksSQL = "" + + "DELETE FROM syncapi_peeks WHERE room_id = $1 AND user_id = $2" + const selectPeeksSQL = "" + "SELECT room_id, new FROM syncapi_peeks WHERE user_id = $1 and device_id = $2" @@ -62,6 +65,7 @@ type peekStatements struct { streamIDStatements *streamIDStatements insertPeekStmt *sql.Stmt deletePeekStmt *sql.Stmt + deletePeeksStmt *sql.Stmt selectPeeksStmt *sql.Stmt selectPeekingDevicesStmt *sql.Stmt markPeeksAsOldStmt *sql.Stmt @@ -82,6 +86,9 @@ func NewSqlitePeeksTable(db *sql.DB, streamID *streamIDStatements) (tables.Peeks if s.deletePeekStmt, err = db.Prepare(deletePeekSQL); err != nil { return nil, err } + if s.deletePeeksStmt, err = db.Prepare(deletePeeksSQL); err != nil { + return nil, err + } if s.selectPeeksStmt, err = db.Prepare(selectPeeksSQL); err != nil { return nil, err } @@ -117,6 +124,17 @@ func (s *peekStatements) DeletePeek( return } +func (s *peekStatements) DeletePeeks( + ctx context.Context, txn *sql.Tx, roomID, userID string, +) (streamPos types.StreamPosition, err error) { + streamPos, err = s.streamIDStatements.nextStreamID(ctx, txn) + if err != nil { + return + } + _, err = sqlutil.TxStmt(txn, s.deletePeeksStmt).ExecContext(ctx, roomID, userID) + return +} + func (s *peekStatements) SelectPeeks( ctx context.Context, txn *sql.Tx, userID, deviceID string, ) (peeks []types.Peek, err error) { diff --git a/syncapi/storage/tables/interface.go b/syncapi/storage/tables/interface.go index b7281f11..3c6ee4bb 100644 --- a/syncapi/storage/tables/interface.go +++ b/syncapi/storage/tables/interface.go @@ -42,6 +42,7 @@ type Invites interface { type Peeks interface { InsertPeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) DeletePeek(ctx context.Context, txn *sql.Tx, roomID, userID, deviceID string) (streamPos types.StreamPosition, err error) + DeletePeeks(ctx context.Context, txn *sql.Tx, roomID, userID string) (streamPos types.StreamPosition, err error) SelectPeeks(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (peeks []types.Peek, err error) SelectPeekingDevices(ctxt context.Context) (peekingDevices map[string][]types.PeekingDevice, err error) MarkPeeksAsOld(ctxt context.Context, txn *sql.Tx, userID, deviceID string) (err error) diff --git a/syncapi/sync/notifier.go b/syncapi/sync/notifier.go index e6f7440e..61300232 100644 --- a/syncapi/sync/notifier.go +++ b/syncapi/sync/notifier.go @@ -248,7 +248,7 @@ func (n *Notifier) wakeupUsers(userIDs []string, peekingDevices []types.PeekingD if peekingDevices != nil { for _, peekingDevice := range peekingDevices { // TODO: don't bother waking up for devices whose users we already woke up - if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.ID, false); stream != nil { + if stream := n.fetchUserDeviceStream(peekingDevice.UserID, peekingDevice.DeviceID, false); stream != nil { stream.Broadcast(newPos) // wake up all goroutines Wait()ing on this stream } } @@ -337,7 +337,7 @@ func (n *Notifier) addPeekingDevice(roomID, userID, deviceID string) { if _, ok := n.roomIDToPeekingDevices[roomID]; !ok { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } - n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{deviceID, userID}) + n.roomIDToPeekingDevices[roomID].add(types.PeekingDevice{userID, deviceID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only @@ -346,7 +346,7 @@ func (n *Notifier) removePeekingDevice(roomID, userID, deviceID string) { n.roomIDToPeekingDevices[roomID] = make(peekingDeviceSet) } // XXX: is this going to work as a key? - n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{deviceID, userID}) + n.roomIDToPeekingDevices[roomID].remove(types.PeekingDevice{userID, deviceID}) } // Not thread-safe: must be called on the OnNewEvent goroutine only diff --git a/syncapi/types/types.go b/syncapi/types/types.go index b9888a65..23d88626 100644 --- a/syncapi/types/types.go +++ b/syncapi/types/types.go @@ -511,8 +511,8 @@ type SendToDeviceEvent struct { } type PeekingDevice struct { - ID string - UserID string + UserID string + DeviceID string } type Peek struct {