64-bit stream IDs for device list updates (#2267)

This commit is contained in:
Neil Alexander 2022-03-10 13:17:28 +00:00 committed by GitHub
parent e1881627d1
commit e485f9c2bd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 33 additions and 37 deletions

View file

@ -49,7 +49,7 @@ type Database interface {
StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceMessage, clearUserIDs []string) error
// PrevIDsExists returns true if all prev IDs exist for this user.
PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error)
PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error)
// DeviceKeysForUser returns the device keys for the device IDs given. If the length of deviceIDs is 0, all devices are selected.
// If there are some missing keys, they are omitted from the returned slice. There is no ordering on the returned slice.

View file

@ -121,7 +121,7 @@ func NewPostgresDeviceKeysTable(db *sql.DB) (tables.DeviceKeys, error) {
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
var streamID int
var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
@ -138,15 +138,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil
}
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
streamID = nullStream.Int64
}
return
}
@ -211,7 +211,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
}
dk.UserID = userID
var keyJSON string
var streamID int
var streamID int64
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err

View file

@ -59,12 +59,8 @@ func (d *Database) DeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage)
return d.DeviceKeysTable.SelectDeviceKeysJSON(ctx, keys)
}
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int) (bool, error) {
sids := make([]int64, len(prevIDs))
for i := range prevIDs {
sids[i] = int64(prevIDs[i])
}
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, sids)
func (d *Database) PrevIDsExists(ctx context.Context, userID string, prevIDs []int64) (bool, error) {
count, err := d.DeviceKeysTable.CountStreamIDsForUser(ctx, userID, prevIDs)
if err != nil {
return false, err
}
@ -85,7 +81,7 @@ func (d *Database) StoreRemoteDeviceKeys(ctx context.Context, keys []api.DeviceM
func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMessage) error {
// work out the latest stream IDs for each user
userIDToStreamID := make(map[string]int)
userIDToStreamID := make(map[string]int64)
for _, k := range keys {
userIDToStreamID[k.UserID] = 0
}
@ -95,7 +91,7 @@ func (d *Database) StoreLocalDeviceKeys(ctx context.Context, keys []api.DeviceMe
if err != nil {
return err
}
userIDToStreamID[userID] = int(streamID)
userIDToStreamID[userID] = streamID
}
// set the stream IDs for each key
for i := range keys {

View file

@ -145,7 +145,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
dk.Type = api.TypeDeviceKeyUpdate
dk.UserID = userID
var keyJSON string
var streamID int
var streamID int64
var displayName sql.NullString
if err := rows.Scan(&dk.DeviceID, &keyJSON, &streamID, &displayName); err != nil {
return nil, err
@ -166,7 +166,7 @@ func (s *deviceKeysStatements) SelectBatchDeviceKeys(ctx context.Context, userID
func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error {
for i, key := range keys {
var keyJSONStr string
var streamID int
var streamID int64
var displayName sql.NullString
err := s.selectDeviceKeysStmt.QueryRowContext(ctx, key.UserID, key.DeviceID).Scan(&keyJSONStr, &streamID, &displayName)
if err != nil && err != sql.ErrNoRows {
@ -183,15 +183,15 @@ func (s *deviceKeysStatements) SelectDeviceKeysJSON(ctx context.Context, keys []
return nil
}
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error) {
func (s *deviceKeysStatements) SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error) {
// nullable if there are no results
var nullStream sql.NullInt32
var nullStream sql.NullInt64
err = sqlutil.TxStmt(txn, s.selectMaxStreamForUserStmt).QueryRowContext(ctx, userID).Scan(&nullStream)
if err == sql.ErrNoRows {
err = nil
}
if nullStream.Valid {
streamID = nullStream.Int32
streamID = nullStream.Int64
}
return
}
@ -204,13 +204,13 @@ func (s *deviceKeysStatements) CountStreamIDsForUser(ctx context.Context, userID
}
query := strings.Replace(countStreamIDsForUserSQL, "($2)", sqlutil.QueryVariadicOffset(len(streamIDs), 1), 1)
// nullable if there are no results
var count sql.NullInt32
var count sql.NullInt64
err := s.db.QueryRowContext(ctx, query, iStreamIDs...).Scan(&count)
if err != nil {
return 0, err
}
if count.Valid {
return int(count.Int32), nil
return int(count.Int64), nil
}
return 0, nil
}

View file

@ -177,7 +177,7 @@ func TestDeviceKeysStreamIDGeneration(t *testing.T) {
if err != nil {
t.Fatalf("DeviceKeysForUser returned error: %s", err)
}
wantStreamIDs := map[string]int{
wantStreamIDs := map[string]int64{
"AAA": 3,
"another_device": 2,
}

View file

@ -37,7 +37,7 @@ type OneTimeKeys interface {
type DeviceKeys interface {
SelectDeviceKeysJSON(ctx context.Context, keys []api.DeviceMessage) error
InsertDeviceKeys(ctx context.Context, txn *sql.Tx, keys []api.DeviceMessage) error
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int32, err error)
SelectMaxStreamIDForUser(ctx context.Context, txn *sql.Tx, userID string) (streamID int64, err error)
CountStreamIDsForUser(ctx context.Context, userID string, streamIDs []int64) (int, error)
SelectBatchDeviceKeys(ctx context.Context, userID string, deviceIDs []string, includeEmpty bool) ([]api.DeviceMessage, error)
DeleteDeviceKeys(ctx context.Context, txn *sql.Tx, userID, deviceID string) error