mirror of
https://github.com/hoernschen/dendrite.git
synced 2024-12-26 15:08:28 +00:00
64-bit stream IDs for device list updates (#2267)
This commit is contained in:
parent
e1881627d1
commit
e485f9c2bd
11 changed files with 33 additions and 37 deletions
|
@ -203,9 +203,9 @@ func (t *KeyChangeConsumer) onCrossSigningMessage(m api.DeviceMessage) bool {
|
|||
return err == nil
|
||||
}
|
||||
|
||||
func prevID(streamID int) []int {
|
||||
func prevID(streamID int64) []int64 {
|
||||
if streamID <= 1 {
|
||||
return nil
|
||||
}
|
||||
return []int{streamID - 1}
|
||||
return []int64{streamID - 1}
|
||||
}
|
||||
|
|
|
@ -70,7 +70,7 @@ type DeviceMessage struct {
|
|||
*DeviceKeys `json:"DeviceKeys,omitempty"`
|
||||
*eduapi.OutputCrossSigningKeyUpdate `json:"CrossSigningKeyUpdate,omitempty"`
|
||||
// A monotonically increasing number which represents device changes for this user.
|
||||
StreamID int
|
||||
StreamID int64
|
||||
DeviceChangeID int64
|
||||
}
|
||||
|
||||
|
@ -108,7 +108,7 @@ type DeviceKeys struct {
|
|||
}
|
||||
|
||||
// WithStreamID returns a copy of this device message with the given stream ID
|
||||
func (k *DeviceKeys) WithStreamID(streamID int) DeviceMessage {
|
||||
func (k *DeviceKeys) WithStreamID(streamID int64) DeviceMessage {
|
||||
return DeviceMessage{
|
||||
DeviceKeys: k,
|
||||
StreamID: streamID,
|
||||
|
@ -281,7 +281,7 @@ type QueryDeviceMessagesRequest struct {
|
|||
|
||||
type QueryDeviceMessagesResponse struct {
|
||||
// The latest stream ID
|
||||
StreamID int
|
||||
StreamID int64
|
||||
Devices []DeviceMessage
|
||||
Error *KeyError
|
||||
}
|
||||
|
|
|
@ -109,7 +109,7 @@ type DeviceListUpdaterDatabase interface {
|
|||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysJSON populates the KeyJSON for the given keys. If any proided `keys` have a `KeyJSON` or `StreamID` already then it will be replaced.
|
||||
DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
|
|
|
@ -46,7 +46,7 @@ func (p *mockKeyChangeProducer) ProduceKeyChanges(keys []api.DeviceMessage) erro
|
|||
|
||||
type mockDeviceListUpdaterDatabase struct {
|
||||
staleUsers map[string]bool
|
||||
prevIDsExist func(string, []int) bool
|
||||
prevIDsExist func(string, []int64) bool
|
||||
storedKeys []api.DeviceMessage
|
||||
mu sync.Mutex // protect staleUsers
|
||||
}
|
||||
|
@ -101,7 +101,7 @@ func (d *mockDeviceListUpdaterDatabase) StoreRemoteDeviceKeys(ctx context.Contex
|
|||
}
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||
func (d *mockDeviceListUpdaterDatabase) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
return d.prevIDsExist(userID, prevIDs), nil
|
||||
}
|
||||
|
||||
|
@ -139,7 +139,7 @@ func newFedClient(tripper func(*http.Request) (*http.Response, error)) *gomatrix
|
|||
func TestUpdateHavePrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
@ -151,7 +151,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
Deleted: false,
|
||||
DeviceID: "FOO",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int{0},
|
||||
PrevID: []int64{0},
|
||||
StreamID: 1,
|
||||
UserID: "@alice:localhost",
|
||||
}
|
||||
|
@ -185,7 +185,7 @@ func TestUpdateHavePrevID(t *testing.T) {
|
|||
func TestUpdateNoPrevID(t *testing.T) {
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
|
@ -226,7 +226,7 @@ func TestUpdateNoPrevID(t *testing.T) {
|
|||
Deleted: false,
|
||||
DeviceID: "another_device_id",
|
||||
Keys: []byte(`{"key":"value"}`),
|
||||
PrevID: []int{3},
|
||||
PrevID: []int64{3},
|
||||
StreamID: 4,
|
||||
UserID: remoteUserID,
|
||||
}
|
||||
|
@ -268,7 +268,7 @@ func TestDebounce(t *testing.T) {
|
|||
t.Skipf("panic on closed channel on GHA")
|
||||
db := &mockDeviceListUpdaterDatabase{
|
||||
staleUsers: make(map[string]bool),
|
||||
prevIDsExist: func(string, []int) bool {
|
||||
prevIDsExist: func(string, []int64) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
|
|
@ -205,7 +205,7 @@ func (a *KeyInternalAPI) QueryDeviceMessages(ctx context.Context, req *api.Query
|
|||
}
|
||||
return
|
||||
}
|
||||
maxStreamID := 0
|
||||
maxStreamID := int64(0)
|
||||
for _, m := range msgs {
|
||||
if m.StreamID > maxStreamID {
|
||||
maxStreamID = m.StreamID
|
||||
|
|
|
@ -49,7 +49,7 @@ type Database interface {
|
|||
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
|
||||
|
||||
// PrevIDsExists returns true if all prev IDs exist for this user.
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
|
||||
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
|
||||
|
||||
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
|
||||
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.
|
||||
|
|
|
@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
|
|||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
|
@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int32
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
}
|
||||
dk.UserID = userID
|
||||
var keyJSON string
|
||||
var streamID int
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
|
|||
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
|
||||
}
|
||||
|
||||
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
|
||||
sids := make([]int64, len(prevIDs))
|
||||
for i := range prevIDs {
|
||||
sids[i] = int64(prevIDs[i])
|
||||
}
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
|
||||
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
|
||||
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
|
|||
|
||||
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
// work out the latest stream IDs for each user
|
||||
userIDToStreamID := make(map[string]int)
|
||||
userIDToStreamID := make(map[string]int64)
|
||||
for _, k := range keys {
|
||||
userIDToStreamID[k.UserID] = 0
|
||||
}
|
||||
|
@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userIDToStreamID[userID] = int(streamID)
|
||||
userIDToStreamID[userID] = streamID
|
||||
}
|
||||
// set the stream IDs for each key
|
||||
for i := range keys {
|
||||
|
|
|
@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
dk.Type = api.TypeDeviceKeyUpdate
|
||||
dk.UserID = userID
|
||||
var keyJSON string
|
||||
var streamID int
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
|
||||
return nil, err
|
||||
|
@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
|
|||
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
|
||||
for i, key := range keys {
|
||||
var keyJSONStr string
|
||||
var streamID int
|
||||
var streamID int64
|
||||
var displayName sql.NullString
|
||||
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
|
@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
|
||||
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
|
||||
// nullable if there are no results
|
||||
var nullStream sql.NullInt32
|
||||
var nullStream sql.NullInt64
|
||||
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
if nullStream.Valid {
|
||||
streamID = nullStream.Int32
|
||||
streamID = nullStream.Int64
|
||||
}
|
||||
return
|
||||
}
|
||||
|
@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
|
|||
}
|
||||
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
|
||||
// nullable if there are no results
|
||||
var count sql.NullInt32
|
||||
var count sql.NullInt64
|
||||
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if count.Valid {
|
||||
return int(count.Int32), nil
|
||||
return int(count.Int64), nil
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
|
|
@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("DeviceKeysForUser returned error: %s", err)
|
||||
}
|
||||
wantStreamIDs := map[string]int{
|
||||
wantStreamIDs := map[string]int64{
|
||||
"AAA": 3,
|
||||
"another_device": 2,
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ type OneTimeKeys interface {
|
|||
type DeviceKeys interface {
|
||||
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
|
||||
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
|
||||
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
|
||||
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
|
||||
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
|
||||
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error
|
||||
|
|
Loading…
Reference in a new issue